device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp Source File

device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp Source File

Composable Kernel: device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp Source File
device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
// Device operation: split-K GEMM using XDL instructions and a C-shuffle
// epilogue; per the file name it loads tiles directly into LDS. Its type
// string is "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" (see GetTypeString).
//
// Template parameters (grouped):
//   A/B/C/AccDataType, A/B/CLayout     - element types and tensor layouts.
//   A/B/CElementwiseOperation          - per-tensor elementwise operators.
//   GemmSpec                           - GEMM specialization.
//   NumGemmKPrefetchStage              - number of K prefetch stages.
//   BlockSize, M/NPerBlock, K0PerBlock, K1, M/NPerXDL, M/NXdlPerWave
//                                      - workgroup / XDL tiling parameters.
//   A/BBlockTransfer*, A/BBlockLdsAddExtra*
//                                      - A/B block-transfer configuration.
//   CShuffle*, CBlockTransfer*         - C-shuffle epilogue configuration.
//   ComputeType                        - defaults to CDataType.
// NOTE(review): the struct name, the LoopSched/PipelineVer template
// parameters (used by GridwiseGemmBase and GetTypeString) and the head of the
// base-class list are on lines not visible in this listing; the base appears
// to be DeviceGemmSplitK — confirm.
23template <typename ADataType,
24 typename BDataType,
25 typename CDataType,
26 typename AccDataType,
27 typename ALayout,
28 typename BLayout,
29 typename CLayout,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
33 GemmSpecialization GemmSpec,
34 ck::index_t NumGemmKPrefetchStage,
35 ck::index_t BlockSize,
36 ck::index_t MPerBlock,
37 ck::index_t NPerBlock,
38 ck::index_t K0PerBlock,
39 ck::index_t K1,
40 ck::index_t MPerXDL,
41 ck::index_t NPerXDL,
42 ck::index_t MXdlPerWave,
43 ck::index_t NXdlPerWave,
44 typename ABlockTransferThreadClusterLengths_K0_M_K1,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferScalarPerVector,
48 bool ABlockLdsAddExtraM,
49 typename BBlockTransferThreadClusterLengths_K0_N_K1,
50 typename BBlockTransferSrcAccessOrder,
51 ck::index_t BBlockTransferSrcVectorDim,
52 ck::index_t BBlockTransferScalarPerVector,
53 bool BBlockLdsAddExtraN,
54 index_t CShuffleMRepeatPerShuffle,
55 index_t CShuffleNRepeatPerShuffle,
56 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
57 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
58 typename ComputeType = CDataType,
// NOTE(review): source lines 59-60 and 62 (remaining template parameters and
// the struct/base-class head) are not visible here.
61
63 BLayout,
64 CLayout,
65 ADataType,
66 BDataType,
67 CDataType,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CElementwiseOperation,
71 ComputeType>
72{
74 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
75 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
76
77 static constexpr auto I0 = Number<0>{};
78 static constexpr auto I1 = Number<1>{};
79 static constexpr auto I2 = Number<2>{};
80 static constexpr auto I3 = Number<3>{};
81
// Gridwise GEMM type, parameterized on NXdlPerWave so that separate wave64
// and wave32 instantiations can be formed (GridwiseGemm64 / GridwiseGemm32
// are declared on lines not visible in this listing). The remaining
// arguments are forwarded from the enclosing template's parameters
// (LoopSched / PipelineVer presumably come from elided template-parameter
// lines — confirm).
82 template <index_t NXdlPerWave_>
// NOTE(review): source line 83 (the `using GridwiseGemmBase =
// GridwiseGemm_xdlops_splitk_lds_direct_load<` head) is not visible here.
84 BlockSize,
85 ADataType,
86 BDataType,
87 AccDataType,
88 CDataType,
89 ALayout,
90 BLayout,
91 CLayout,
92 AElementwiseOperation,
93 BElementwiseOperation,
94 CElementwiseOperation,
95 GemmSpec,
96 NumGemmKPrefetchStage,
97 MPerBlock,
98 NPerBlock,
99 K0PerBlock,
// NOTE(review): K1 is passed after MPerXDL/NPerXDL here, unlike this struct's
// own template-parameter order; all three are index_t, so a mismatch would
// compile silently — confirm against the gridwise template's declaration.
100 MPerXDL,
101 NPerXDL,
102 K1,
103 MXdlPerWave,
104 NXdlPerWave_,
105 ABlockTransferThreadClusterLengths_K0_M_K1,
106 ABlockTransferSrcAccessOrder,
107 ABlockTransferSrcVectorDim,
108 ABlockTransferScalarPerVector,
109 ABlockLdsAddExtraM,
110 BBlockTransferThreadClusterLengths_K0_N_K1,
111 BBlockTransferSrcAccessOrder,
112 BBlockTransferSrcVectorDim,
113 BBlockTransferScalarPerVector,
114 BBlockLdsAddExtraN,
115 CShuffleMRepeatPerShuffle,
116 CShuffleNRepeatPerShuffle,
117 CBlockTransferScalarPerVector_NWaveNPerXDL,
118 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
119 LoopSched,
120 PipelineVer,
121 ComputeType>;
124
125 struct Argument : public GridwiseGemm64::Argument
126 {
127 Argument(const ADataType* p_a_grid_,
128 const BDataType* p_b_grid_,
129 CDataType* p_c_grid_,
130 index_t M_,
131 index_t N_,
132 index_t K_,
133 index_t StrideA_,
134 index_t StrideB_,
135 index_t StrideC_,
136 index_t MPadded_,
137 index_t NPadded_,
138 index_t KPadded_,
139 index_t K0Padded_,
140 index_t k_batch_,
141 AElementwiseOperation a_element_op_,
142 BElementwiseOperation b_element_op_,
143 CElementwiseOperation c_element_op_)
144 : GridwiseGemm64::Argument(p_a_grid_,
145 p_b_grid_,
146 p_c_grid_,
147 M_,
148 N_,
149 K_,
150 StrideA_,
151 StrideB_,
152 StrideC_,
153 MPadded_,
154 NPadded_,
155 KPadded_,
156 K0Padded_,
157 k_batch_),
158 a_element_op(a_element_op_),
159 b_element_op(b_element_op_),
160 c_element_op(c_element_op_)
161 {
162 }
163
164 AElementwiseOperation a_element_op;
165 BElementwiseOperation b_element_op;
166 CElementwiseOperation c_element_op;
167 };
168
170
171 // Invoker
172 struct Invoker : public BaseInvoker
173 {
/// Dumps an argument by delegating to its own const Print() member
/// (RunImp calls this when stream_config.log_level_ > 0).
template <typename PrintableArg>
void Print(const PrintableArg& arg)
{
    arg.Print();
}
179
// Launches the split-K GEMM with the given gridwise GEMM instantiation
// (presumably GridwiseGemm64 or GridwiseGemm32, chosen by the typed Run
// overload that INVOKER_RUN_IMPL appears to provide — confirm).
//
// Parameters:
//   karg          - device Argument (problem description + elementwise ops).
//   stream_config - stream / timing configuration; log_level_ > 0 prints karg.
// Returns the time reported by launch_and_time_kernel.
// Throws std::runtime_error if GridwiseGemm::CheckValidity rejects the argument.
180 template <typename GridwiseGemm>
181 float RunImp(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
182 {
183 if(stream_config.log_level_ > 0)
184 {
185 Print(karg);
186 }
187
188 const auto kbatch = karg.k_batch;
// Copy the gridwise part of karg as a GridwiseGemm::Argument. Argument derives
// from GridwiseGemm64::Argument, so when GridwiseGemm is another instantiation
// this reinterpret_cast assumes an identical layout.
// NOTE(review): confirm GridwiseGemm32::Argument is layout-compatible.
189 auto arg = *reinterpret_cast<const typename GridwiseGemm::Argument*>(&karg);
// NOTE(review): the message below names
// GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2, but this device op instantiates
// GridwiseGemm_xdlops_splitk_lds_direct_load; the text looks stale/copied.
190 if(!GridwiseGemm::CheckValidity(arg))
191 {
192 throw std::runtime_error(
193 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
194 "setting");
195 }
196
// Grid dimensions come from the default block-to-C-tile map, computed from
// M, N and the split-K factor.
197 const auto b2c_map = DefaultBlock2CTileMap{};
198 index_t gdx, gdy, gdz;
199 ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
200 const auto K0Padded = karg.K0Padded;
201
// The padded K0 extent decides whether the main K0 loop runs; this selects the
// kernel instantiation below.
202 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
203
204 float ave_time = 0;
205
// Launch helper shared by all four kernel instantiations.
// For split-K (kbatch > 1) C is zeroed asynchronously on the same stream before
// the launch, presumably because those kernels accumulate partial results into
// C (the elided template arguments likely select AtomicAdd vs Set — confirm).
// NOTE(review): the hipMemsetAsync status is only passed to hipGetErrorString
//   and the resulting string is discarded, so a failed memset goes unnoticed.
// NOTE(review): karg.M * karg.N is evaluated before the multiplication by
//   sizeof(CDataType); if M and N are stored as index_t (int32_t) this can
//   overflow for very large C. The size also assumes C occupies exactly M*N
//   contiguous elements (StrideC is not considered).
206 const auto Run = [&](const auto& kernel) {
207 if(kbatch > 1)
208 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
209 0,
210 karg.M * karg.N * sizeof(CDataType),
211 stream_config.stream_id_));
212
213 ave_time = launch_and_time_kernel(stream_config,
214 kernel,
215 dim3(gdx, gdy, gdz),
216 dim3(BlockSize),
217 0,
218 arg,
219 b2c_map,
220 karg.a_element_op,
221 karg.b_element_op,
222 karg.c_element_op);
223 };
224
// Dispatch on {main K0 loop present?} x {kbatch == 1 vs kbatch > 1}. The kernel
// names and some template arguments are on lines not visible in this listing.
225 if(has_main_k0_block_loop)
226 {
227 if(kbatch == 1)
228 {
229 const auto kernel =
231 true,
234 AElementwiseOperation,
235 BElementwiseOperation,
236 CElementwiseOperation>;
237
238 Run(kernel);
239 }
240 else
241 {
243 GridwiseGemm,
244 true,
247 AElementwiseOperation,
248 BElementwiseOperation,
249 CElementwiseOperation>;
250
251 Run(kernel);
252 }
253 }
254 else
255 {
256 if(kbatch == 1)
257 {
258 const auto kernel =
260 false,
263 AElementwiseOperation,
264 BElementwiseOperation,
265 CElementwiseOperation>;
266
267 Run(kernel);
268 }
269 else
270 {
272 GridwiseGemm,
273 false,
276 AElementwiseOperation,
277 BElementwiseOperation,
278 CElementwiseOperation>;
279
280 Run(kernel);
281 }
282 }
283
284 return ave_time;
285 }
286
288
289 // polymorphic
290 float Run(const BaseArgument* p_arg,
291 const StreamConfig& stream_config = StreamConfig{}) override
292 {
293 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
294 }
295 };
296
/// Compile-time sanity check of the template parameters.
/// Currently a stub that unconditionally reports success.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool always_valid = true;
    return always_valid;
}
302
// Host-side check that this instance can run the problem described by karg.
// The two leading early-outs test conditions written on lines not visible in
// this listing — presumably device capabilities such as XDL and LDS
// direct-load support (is_xdl_wmma_supported / is_lds_direct_load_supported
// appear in the cross-references) — confirm.
// The remaining check is dispatched on the runtime wavefront size: the wave64
// or wave32 gridwise GEMM validates the argument, but only when the matching
// NXdlPerWave value is > 0 at compile time; otherwise false is returned.
303 static bool IsSupportedArgument(const Argument& karg)
304 {
306 {
307 return false;
308 }
310 {
311 return false;
312 }
313
314 if(get_warp_size() == 64)
315 {
316 if constexpr(NXdlPerWave64 > 0)
317 {
319 }
320 }
321 else
322 {
323 if constexpr(NXdlPerWave32 > 0)
324 {
// Argument derives from GridwiseGemm64::Argument; this cast assumes
// GridwiseGemm32::Argument has an identical layout.
326 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
327 }
328 }
329 return false;
330 }
331
332 // polymorphic
333 bool IsSupportedArgument(const BaseArgument* p_arg) override
334 {
335 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
336 }
337
// Builds an Argument (by value) for an M x N x K problem with the given
// strides, elementwise operators and split-K factor KBatch (stored as
// k_batch; values > 1 take the split-K path in Invoker::RunImp).
// NOTE(review): the padded extents (MPadded/NPadded/KPadded/K0Padded) are
// passed on lines not visible in this listing — presumably computed from
// M/N/K by gridwise GEMM helpers; confirm.
338 static auto MakeArgument(const ADataType* p_a,
339 const BDataType* p_b,
340 CDataType* p_c,
341 index_t M,
342 index_t N,
343 index_t K,
344 index_t StrideA,
345 index_t StrideB,
346 index_t StrideC,
347 AElementwiseOperation a_element_op,
348 BElementwiseOperation b_element_op,
349 CElementwiseOperation c_element_op,
350 index_t KBatch)
351 {
352 return Argument(p_a,
353 p_b,
354 p_c,
355 M,
356 N,
357 K,
358 StrideA,
359 StrideB,
360 StrideC,
365 KBatch,
366 a_element_op,
367 b_element_op,
368 c_element_op);
369 }
370
371 static auto MakeInvoker() { return Invoker{}; }
372
373 // polymorphic
// Type-erased counterpart of MakeArgument: casts the void pointers to the
// element types and heap-allocates an Argument returned as BaseArgument.
// KBatch defaults to 1 (no K split).
// NOTE(review): the padded extents passed to Argument are on lines not
// visible in this listing.
374 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
375 const void* p_b,
376 void* p_c,
377 index_t M,
378 index_t N,
379 index_t K,
380 index_t StrideA,
381 index_t StrideB,
382 index_t StrideC,
383 AElementwiseOperation a_element_op,
384 BElementwiseOperation b_element_op,
385 CElementwiseOperation c_element_op,
386 ck::index_t KBatch = 1) override
387 {
388 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
389 static_cast<const BDataType*>(p_b),
390 static_cast<CDataType*>(p_c),
391 M,
392 N,
393 K,
394 StrideA,
395 StrideB,
396 StrideC,
401 KBatch,
402 a_element_op,
403 b_element_op,
404 c_element_op);
405 }
406
407 // polymorphic
408 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
409 {
410 return std::make_unique<Invoker>(Invoker{});
411 }
412
413 // polymorphic
// Human-readable description of this instance: op name, tile / XDL /
// vector-width parameters, GemmSpec, loop scheduler, pipeline version and
// number of K prefetch stages.
// NOTE(review): lookups use std::map::operator[], which inserts and prints an
// empty string for an unmapped key; LoopSchedToString maps only Default and
// Interwave.
414 std::string GetTypeString() const override
415 {
416 auto str = std::stringstream();
417
418 std::map<LoopScheduler, std::string> LoopSchedToString{
419 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
420
// The entries of this map are on a line not visible in this listing
// (presumably v1/v2/v4 — confirm).
421 std::map<PipelineVersion, std::string> PipelineVersionToString{
423
424 // clang-format off
425 str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad"
426 << "<"
427 << BlockSize << ", "
428 << MPerBlock << ", "
429 << NPerBlock << ", "
430 << K0PerBlock << ", "
431 << K1 << ", "
432 << MPerXDL << ", "
433 << NPerXDL << ", "
434 << MXdlPerWave << ", "
435 << NXdlPerWave << ", "
436 << ABlockTransferScalarPerVector << ", "
437 << BBlockTransferScalarPerVector << ", "
438 << CShuffleMRepeatPerShuffle << ", "
439 << CShuffleNRepeatPerShuffle << ", "
440 << getGemmSpecializationString(GemmSpec)
441 << ">"
442 << " LoopScheduler: "
443 << LoopSchedToString[LoopSched] << ", "
444 << "PipelineVersion: "
445 << PipelineVersionToString[PipelineVer] << ", "
446 << "Prefetch: "
447 << NumGemmKPrefetchStage;
448 // clang-format on
449
450 return str.str();
451 }
452};
453
454} // namespace device
455} // namespace tensor_operation
456} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__global__ void kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:35
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:99
Definition device_base.hpp:197
Definition device_gemm_splitk.hpp:26
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:126
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:127
BElementwiseOperation b_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:165
AElementwiseOperation a_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:164
CElementwiseOperation c_element_op
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:166
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:173
void Print(const Argument_ &karg)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:175
float RunImp(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:181
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:290
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:72
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:408
static constexpr auto I1
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:78
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:74
typename GridwiseGemm64::DefaultBlock2CTileMap DefaultBlock2CTileMap
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:169
GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeType > GridwiseGemmBase
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:83
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:123
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:374
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:122
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:303
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch)
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:338
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:333
std::string GetTypeString() const override
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:414
static constexpr auto I3
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:80
static auto MakeInvoker()
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:371
static constexpr auto I0
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:77
static constexpr auto I2
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:79
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:297
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp:75