device_moe_mx_gemm.hpp Source File

device_moe_mx_gemm.hpp Source File

Composable Kernel: device_moe_mx_gemm.hpp Source File
device_moe_mx_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename AScaleDataType,
30 typename BDataType,
31 typename BScaleDataType,
32 typename DsDataType,
33 typename CDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 GemmSpecialization GemmSpec,
40 index_t ScaleBlockSize,
41 index_t BlockSize,
42 index_t MPerBlock,
43 index_t NPerBlock,
44 index_t KPerBlock,
45 index_t AK1,
46 index_t BK1,
47 index_t MPerXDL,
48 index_t NPerXDL,
49 index_t MXdlPerWave,
50 index_t NXdlPerWave,
51 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52 typename ABlockTransferThreadClusterArrangeOrder,
53 typename ABlockTransferSrcAccessOrder,
54 index_t ABlockTransferSrcVectorDim,
55 index_t ABlockTransferSrcScalarPerVector,
56 index_t ABlockTransferDstScalarPerVector_AK1,
57 bool ABlockLdsExtraM,
58 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59 typename BBlockTransferThreadClusterArrangeOrder,
60 typename BBlockTransferSrcAccessOrder,
61 index_t BBlockTransferSrcVectorDim,
62 index_t BBlockTransferSrcScalarPerVector,
63 index_t BBlockTransferDstScalarPerVector_BK1,
64 bool BBlockLdsExtraN,
65 index_t CShuffleMXdlPerWavePerShuffle,
66 index_t CShuffleNXdlPerWavePerShuffle,
67 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68 typename CDEShuffleBlockTransferScalarPerVectors,
71 index_t ActivationOP = 0,
72 bool NSwizzle = false,
73 bool IsInputGemm = true,
74 bool MulRoutedWeight = true,
75 typename IndexType = index_t,
76 typename ComputeTypeA = ADataType,
77 typename ComputeTypeB = BDataType>
79 BLayout,
80 DsLayout,
81 CLayout,
82 ADataType,
83 AScaleDataType,
84 BDataType,
85 BScaleDataType,
86 DsDataType,
87 CDataType,
88 ScaleBlockSize,
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CElementwiseOperation>
92{
// Per-wave N-XDL repeat counts obtained from GetNXdlPerWave<true>() / GetNXdlPerWave<false>().
// NOTE(review): GetNXdlPerWave is not defined in the visible listing (listing line 93 is
// elided); presumably <true> selects the wave64 value and <false> the wave32 value, as the
// constant names suggest -- confirm against device_base.hpp.
94 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
95 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
// Number of D tensors, taken from the size of the DsDataType tuple.
96 static constexpr index_t NumDTensor = DsDataType::Size();
97 template <index_t NXdlPerWave_>
99 GridwiseMoeGemmMX<ALayout,
100 BLayout,
101 DsLayout,
102 CLayout,
103 ADataType,
104 AScaleDataType,
105 BDataType,
106 BScaleDataType,
107 GemmAccDataType,
108 CShuffleDataType,
109 DsDataType,
110 CDataType,
111 AElementwiseOperation,
112 BElementwiseOperation,
113 CElementwiseOperation,
114 GemmSpec,
115 ScaleBlockSize,
116 BlockSize,
117 MPerBlock,
118 NPerBlock,
119 KPerBlock,
120 AK1,
121 BK1,
122 MPerXDL,
123 NPerXDL,
124 MXdlPerWave,
125 NXdlPerWave_,
126 ABlockTransferThreadClusterLengths_AK0_M_AK1,
127 ABlockTransferThreadClusterArrangeOrder,
128 ABlockTransferSrcAccessOrder,
129 ABlockTransferSrcVectorDim,
130 ABlockTransferSrcScalarPerVector,
131 ABlockTransferDstScalarPerVector_AK1,
132 false,
133 ABlockLdsExtraM,
134 BBlockTransferThreadClusterLengths_BK0_N_BK1,
135 BBlockTransferThreadClusterArrangeOrder,
136 BBlockTransferSrcAccessOrder,
137 BBlockTransferSrcVectorDim,
138 BBlockTransferSrcScalarPerVector,
139 BBlockTransferDstScalarPerVector_BK1,
140 false,
141 BBlockLdsExtraN,
142 CShuffleMXdlPerWavePerShuffle,
143 CShuffleNXdlPerWavePerShuffle,
144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145 CDEShuffleBlockTransferScalarPerVectors,
146 BlkGemmPipeSched,
147 BlkGemmPipelineVer,
148 ActivationOP,
149 NSwizzle,
150 IsInputGemm,
151 MulRoutedWeight,
152 IndexType,
153 ComputeTypeA,
154 ComputeTypeB>;
157
// Kernel argument type exposed by this device op. It is always the GridwiseGemm64 argument
// type; the wave32 path in IsSupportedArgument reinterprets it as GridwiseGemm32::Argument.
158 using Argument = typename GridwiseGemm64::Argument;
161
// Returns the NPerXDL template parameter of this instantiation (override of a base-class
// virtual).
// NOTE(review): presumably callers use this value to pre-shuffle the B tensor -- confirm.
162 int GetPreShuffleParameters() override { return NPerXDL; }
163
164 // Invoker
165 struct Invoker : public BaseInvoker
166 {
// Launches the MoE MX GEMM kernel for the given GridwiseGemm instantiation and returns the
// time reported by the launch helper (ave_time; stays 0 if no kernel is launched).
//
// Parameters:
//   arg           - kernel argument of type GridwiseGemm::Argument.
//   stream_config - launch/timing configuration; its log_level_, flush_cache, rotating_count
//                   and stream_id_ members are read here.
// Throws:
//   std::runtime_error if GridwiseGemm::CheckValidity(arg) fails, or if the main-K-loop path
//   is taken with a pipeline version other than v1/v3.
//
// NOTE(review): this listing elides several source lines (listing lines 227, 238, 270, 281,
// 292, 301, 319, 330, 339), so the kernel template argument lists, the MemoryDataOp
// initialiser and the flush-cache launch call are only partially visible below.
167 template <typename GridwiseGemm>
168 float RunImp(const typename GridwiseGemm::Argument& arg,
169 const StreamConfig& stream_config = StreamConfig{})
170 {
// Dump the argument when logging is enabled.
171 if(stream_config.log_level_ > 0)
172 {
173 arg.Print();
174 }
175
// Reject arguments the gridwise kernel cannot handle before doing any work.
176 if(!GridwiseGemm::CheckValidity(arg))
177 {
178 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
179 }
180
181 index_t gdx, gdy, gdz;
182 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
183
184 float ave_time = 0;
185
// K_split = ceil(K / (KBatch * KPerBlock)) * KPerBlock, i.e. the K extent handled per
// split, rounded up to a multiple of KPerBlock.
186 index_t k_grain = arg.KBatch * KPerBlock;
187 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
188
189 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
190
// Launch helper shared by all kernel variants selected below; writes the timing into
// ave_time (captured by reference).
191 const auto RunKernel = [&](const auto& kernel) {
192 if(stream_config.flush_cache)
193 {
194
195 std::array<std::size_t, NumDTensor> DsSize;
196
// Work on a copy of the argument; the copy is handed to the rotating-memory wrapper.
197 auto arg_ = arg;
198
199 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
200 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
201 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
202 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
203
// Buffer sizes passed to the rotating-memory wrapper: element-space size times element
// size.
204 auto size_a_buffer =
205 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
206 auto size_b_buffer =
207 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
208
209 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
210 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
211
212 static_for<0, NumDTensor, 1>{}([&](auto i) {
213 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
214 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
215 });
216 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
217 DsDataType>
218 rotating_mem(arg_,
219 stream_config.rotating_count,
220 size_a_buffer,
221 size_b_buffer,
222 DsSize);
223 rotating_mem.Print();
224
// Pre-launch step passed to the launch helper below.
// NOTE(review): the statement under "flush icache" (listing line 227) is elided here.
225 auto run_flush_cache = [&]() {
226 // flush icache
228 // rotating mem
229 rotating_mem.Next();
230 // clear c mem
// NOTE(review): the string returned by hipGetErrorString is discarded, so a failing
// hipMemsetAsync is not reported.
231 if(arg_.KBatch > 1)
232 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
233 0,
234 arg_.M * arg_.N * sizeof(CDataType),
235 stream_config.stream_id_));
236 };
237
// NOTE(review): the call expression these arguments belong to (listing line 238) is elided;
// presumably it assigns ave_time from launch_and_time_kernel_with_preprocess -- confirm.
239 stream_config,
240 run_flush_cache,
241 kernel,
242 dim3(gdx, gdy, gdz),
243 dim3(BlockSize),
244 0,
245 arg_);
246 }
247 else
248 {
// Split-K: zero C before the launch. The hipMemsetAsync status is not checked (see above).
249 if(arg.KBatch > 1)
250 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
251 0,
252 arg.M * arg.N * sizeof(CDataType),
253 stream_config.stream_id_));
254
255 ave_time = launch_and_time_kernel(
256 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
257 }
258 };
259
// Occupancy value passed as a kernel template argument:
//   - scheduler other than Intrawave          -> 2
//   - Intrawave, v3 and tile size within limit -> 2
//   - Intrawave otherwise                      -> 1
260 // TODO: Check if this is the right algorithm for minimum_occupancy
261 constexpr index_t minimum_occupancy =
262 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
263 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
264 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
265 ? 2
266 : 1
267 : 2;
268
// NOTE(review): the initialiser of MemoryDataOp (listing line 270) is elided in this listing.
269 constexpr auto MemoryDataOp =
271
// Kernel selection: second template argument is has_main_k_block_loop (true/false); for v3
// the variant additionally depends on whether the K-loop tail number is Odd.
// NOTE(review): the trailing template arguments of every kernel_moe_mxgemm_2lds<...> below
// are on elided listing lines.
272 if(has_main_k_block_loop)
273 {
274 // Tail number always full
275 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
276 {
277 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
278 true,
279 MemoryDataOp,
280 minimum_occupancy,
282 RunKernel(kernel);
283 }
284 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
285 {
286 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
287 {
288 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
289 true,
290 MemoryDataOp,
291 minimum_occupancy,
293 RunKernel(kernel);
294 }
295 else
296 {
297 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
298 true,
299 MemoryDataOp,
300 minimum_occupancy,
302 RunKernel(kernel);
303 }
304 }
305 else
306 {
307 throw std::runtime_error("todo: only v1 & v3 support now");
308 }
309 }
310 else
311 {
// No main K loop. Unlike the branch above there is no final else here: for a pipeline
// version other than v1/v3 nothing is launched and ave_time stays 0.
312 // Tail number always full
313 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
314 {
315 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
316 false,
317 MemoryDataOp,
318 minimum_occupancy,
320 RunKernel(kernel);
321 }
322 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
323 {
324 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
325 {
326 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
327 false,
328 MemoryDataOp,
329 minimum_occupancy,
331 RunKernel(kernel);
332 }
333 else
334 {
335 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
336 false,
337 MemoryDataOp,
338 minimum_occupancy,
340 RunKernel(kernel);
341 }
342 }
343 }
344
345 return ave_time;
346 }
347
349
350 // polymorphic
// Type-erased entry point: downcasts p_arg to Argument and forwards to the Run overload
// taking a const Argument& (that overload is not visible in this listing; listing line 348
// is elided).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing
// nullptr or an argument of a different dynamic type is undefined behaviour.
351 float Run(const BaseArgument* p_arg,
352 const StreamConfig& stream_config = StreamConfig{}) override
353 {
354 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
355 }
356 };
357
// Compile-time validity hook for this instantiation. Currently a stub that always reports
// true; no template parameter is actually inspected (see TODO).
358 static constexpr bool IsValidCompilationParameter()
359 {
360 // TODO: properly implement this check
361 return true;
362 }
363
// Host-side check of whether this instantiation can run the given argument.
// Returns false as soon as any of the checks below fails; otherwise the result comes from
// the warp-size-specific branch at the end.
// NOTE(review): listing lines 371, 396 and 403 are elided, so one rejection condition and
// the statements inside the wave64/wave32 branches are not fully visible here.
364 static bool IsSupportedArgument(const Argument& arg)
365 {
366 // only impl kbatch 1 now
367 if(arg.KBatch > 1)
368 {
369 return false;
370 }
// NOTE(review): the condition guarding this block (listing line 371) is elided.
372 {
373 return false;
374 }
// NOTE(review): this condition requires arg.KBatch > 1, which was already rejected above,
// so this branch cannot be taken as the code stands.
375 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
376 {
377 return false;
378 }
379
// K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the K-padding
// specialisations (MKPadding, NKPadding, MNKPadding, KPadding).
380 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
381 GemmSpec == GemmSpecialization::NKPadding ||
382 GemmSpec == GemmSpecialization::MNKPadding ||
383 GemmSpec == GemmSpecialization::KPadding))
384 {
385 return false;
386 }
// N and K must be whole multiples of the block tile, regardless of GemmSpec.
387 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
388 {
389 return false;
390 }
391
// Dispatch on the warp size. A branch is only compiled in when the corresponding
// NXdlPerWave constant is positive; if it is not, control falls through to the final
// return false.
392 if(get_warp_size() == 64)
393 {
394 if constexpr(NXdlPerWave64 > 0)
395 {
// NOTE(review): body elided (listing line 396); presumably returns the GridwiseGemm64
// validity check -- confirm.
397 }
398 }
399 else
400 {
401 if constexpr(NXdlPerWave32 > 0)
402 {
// NOTE(review): start of this statement (listing line 403) is elided; the visible part
// reinterprets arg as GridwiseGemm32::Argument.
404 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
405 }
406 }
407 return false;
408 }
409
410 // polymorphic
// Type-erased overload: downcasts p_arg to Argument and forwards to the static overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing
// nullptr or an argument of a different dynamic type is undefined behaviour.
411 bool IsSupportedArgument(const BaseArgument* p_arg) override
412 {
413 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
414 }
415
// Builds an Argument by value from untyped pointers plus problem sizes, strides and
// element-wise operators.
// The three routing pointers are cast to const index_t*, the A/B data and scale pointers to
// their respective template data types, and p_c to CDataType*; everything else is passed
// through unchanged, in the same order as the parameter list.
// NOTE(review): the routing pointers are cast to index_t rather than the IndexType template
// parameter -- confirm this matches the Argument constructor.
416 static auto MakeArgument(const void* p_sorted_token_ids,
417 const void* p_sorted_expert_ids,
418 const void* p_max_token_id,
419 const void* p_a,
420 const void* p_a_scale,
421 const void* p_b,
422 const void* p_b_scale,
423 std::array<const void*, NumDTensor> p_ds,
424 void* p_c,
425 index_t NumTokens,
426 index_t TopK,
427 index_t M,
428 index_t N,
429 index_t K,
430 index_t StrideA,
431 index_t StrideScaleA,
432 index_t StrideB,
433 index_t StrideScaleB,
434 std::array<index_t, NumDTensor> StrideDs,
435 index_t StrideC,
436 index_t KBatch,
437 AElementwiseOperation a_element_op,
438 BElementwiseOperation b_element_op,
439 CElementwiseOperation c_element_op)
440 {
441 return Argument{static_cast<const index_t*>(p_sorted_token_ids),
442 static_cast<const index_t*>(p_sorted_expert_ids),
443 static_cast<const index_t*>(p_max_token_id),
444 static_cast<const ADataType*>(p_a),
445 static_cast<const AScaleDataType*>(p_a_scale),
446 static_cast<const BDataType*>(p_b),
447 static_cast<const BScaleDataType*>(p_b_scale),
448 p_ds,
449 static_cast<CDataType*>(p_c),
450 NumTokens,
451 TopK,
452 M,
453 N,
454 K,
455 StrideA,
456 StrideScaleA,
457 StrideB,
458 StrideScaleB,
459 StrideDs,
460 StrideC,
461 KBatch,
462 a_element_op,
463 b_element_op,
464 c_element_op};
465 }
466
// Returns a default-constructed Invoker by value.
467 static auto MakeInvoker() { return Invoker{}; }
468
469 // polymorphic
// Type-erased factory: heap-allocates an Argument and returns it as a BaseArgument.
// This interface has no routing inputs, so the first three constructor arguments (the
// sorted-token-ids, sorted-expert-ids and max-token-id pointers in MakeArgument's order) are
// nullptr, the NumTokens slot receives M and the TopK slot receives 0. All other values are
// forwarded as in MakeArgument.
// NOTE(review): a kernel launched with this argument would see null routing pointers --
// confirm the gridwise kernel tolerates that, or that this path is never launched.
470 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
471 const void* p_a_scale,
472 const void* p_b,
473 const void* p_b_scale,
474 std::array<const void*, NumDTensor> p_ds,
475 void* p_c,
476 index_t M,
477 index_t N,
478 index_t K,
479 index_t StrideA,
480 index_t StrideScaleA,
481 index_t StrideB,
482 index_t StrideScaleB,
483 std::array<ck::index_t, NumDTensor> StrideDs,
484 index_t StrideC,
485 index_t KBatch,
486 AElementwiseOperation a_element_op,
487 BElementwiseOperation b_element_op,
488 CElementwiseOperation c_element_op) override
489 {
490 return std::make_unique<Argument>(nullptr,
491 nullptr,
492 nullptr,
493 static_cast<const ADataType*>(p_a),
494 static_cast<const AScaleDataType*>(p_a_scale),
495 static_cast<const BDataType*>(p_b),
496 static_cast<const BScaleDataType*>(p_b_scale),
497 p_ds,
498 static_cast<CDataType*>(p_c),
499 M, // NumTokens slot: arbitrary placeholder (original note: "randoms set, no use")
500 0,
501 M,
502 N,
503 K,
504 StrideA,
505 StrideScaleA,
506 StrideB,
507 StrideScaleB,
508 StrideDs,
509 StrideC,
510 KBatch,
511 a_element_op,
512 b_element_op,
513 c_element_op);
514 }
515
516 // polymorphic
// Type-erased factory: returns a heap-allocated Invoker as a BaseInvoker. The Invoker is
// copy/move-constructed from a default-constructed temporary.
517 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
518 {
519 return std::make_unique<Invoker>(Invoker{});
520 }
521
522 // polymorphic
// Returns a human-readable description of this instantiation: GEMM specialisation, first
// letter of each of the A/B/C layout names, block size, block tile, wave tile, wave map,
// global-read vector widths, pipeline scheduler/version names and the prefetch stage count
// of GridwiseGemm64's blockwise pipeline.
// NOTE(review): the initialiser lists of the two enum-to-string maps (listing lines 528-529
// and 532-536) are elided in this listing. Both maps are indexed with operator[], which
// yields an empty string for an enum value that has no entry.
523 std::string GetTypeString() const override
524 {
525 auto str = std::stringstream();
526
527 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
530
531 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
537
538 // clang-format off
539 str << "DeviceMoeGEmmMx"
540 << "<"
541 << getGemmSpecializationString(GemmSpec) << ", "
542 << std::string(ALayout::name)[0]
543 << std::string(BLayout::name)[0]
544 << std::string(CLayout::name)[0]
545 << ">"
546 << " BlkSize: "
547 << BlockSize << ", "
548 << "BlkTile: "
549 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
550 << "WaveTile: "
551 << MPerXDL<<"x"<<NPerXDL << ", "
552 << "WaveMap: "
553 << MXdlPerWave<<"x" << NXdlPerWave<<", "
554 << "VmemReadVec: "
555 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
556 << "BlkGemmPipelineScheduler: "
557 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
558 << "BlkGemmPipelineVersion: "
559 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
560 << "BlkGemmPipelinePrefetchStages: "
561 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
562 // clang-format on
563
564 return str.str();
565 }
566};
567
568} // namespace device
569} // namespace tensor_operation
570} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
constexpr index_t packed_size_v
Definition data_type.hpp:411
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_mx_gemm.hpp:179
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d.hpp:167
Definition device_moe_mx_gemm.hpp:166
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_mx_gemm.hpp:351
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_mx_gemm.hpp:168
Definition device_moe_mx_gemm.hpp:92
typename GridwiseGemm64::Argument Argument
Definition device_moe_mx_gemm.hpp:158
std::string GetTypeString() const override
Definition device_moe_mx_gemm.hpp:523
static constexpr bool IsValidCompilationParameter()
Definition device_moe_mx_gemm.hpp:358
static constexpr index_t APackedSize
Definition device_moe_mx_gemm.hpp:159
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_mx_gemm.hpp:155
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_mx_gemm.hpp:364
GridwiseMoeGemmMX< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_moe_mx_gemm.hpp:98
static constexpr index_t NumDTensor
Definition device_moe_mx_gemm.hpp:96
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_mx_gemm.hpp:416
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_mx_gemm.hpp:94
int GetPreShuffleParameters() override
Definition device_moe_mx_gemm.hpp:162
static constexpr index_t BPackedSize
Definition device_moe_mx_gemm.hpp:160
static constexpr auto NXdlPerWave32
Definition device_moe_mx_gemm.hpp:95
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_mx_gemm.hpp:156
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_mx_gemm.hpp:411
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_mx_gemm.hpp:470
static auto MakeInvoker()
Definition device_moe_mx_gemm.hpp:467
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_mx_gemm.hpp:517
Definition flush_cache.hpp:174