device_gemm_xdl_streamk.hpp Source File

device_gemm_xdl_streamk.hpp Source File

Composable Kernel: device_gemm_xdl_streamk.hpp Source File
device_gemm_xdl_streamk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ADataType,
25 typename BDataType,
26 typename CDataType,
27 typename AccDataType,
28 typename ALayout,
29 typename BLayout,
30 typename CLayout,
31 typename AElementwiseOperation,
32 typename BElementwiseOperation,
33 typename CElementwiseOperation,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 ck::index_t MPerXDL,
40 ck::index_t NPerXDL,
41 ck::index_t MXdlPerWave,
42 ck::index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 ck::index_t ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 ck::index_t BBlockLdsAddExtraN,
57 index_t CShuffleMRepeatPerShuffle,
58 index_t CShuffleNRepeatPerShuffle,
59 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
62 BLayout,
63 CLayout,
64 ADataType,
65 BDataType,
66 CDataType,
67 AElementwiseOperation,
68 BElementwiseOperation,
69 CElementwiseOperation>
70{
// NXdlPerWave resolved separately for the two wave sizes via GetNXdlPerWave
// (supplied by GET_NXDL_PER_WAVE_IMPL). The <true>/<false> argument presumably
// selects wave64/wave32 — confirm against device_base.hpp. The runtime paths
// test `NXdlPerWave64 > 0` / `NXdlPerWave32 > 0` before touching the matching
// gridwise GEMM, so a non-positive value disables that wave size.
72 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
73 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
74
75 static constexpr auto I0 = Number<0>{};
76 static constexpr auto I1 = Number<1>{};
77 static constexpr auto I2 = Number<2>{};
78 static constexpr auto I3 = Number<3>{};
79
80 template <index_t NXdlPerWave_>
82 BlockSize,
84 NPerBlock,
85 K0PerBlock * K1,
87 ADataType, // TODO: distinguish A/B datatype
88 AccDataType,
89 CDataType,
90 ALayout,
91 BLayout,
92 CLayout,
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
96 MPerBlock,
97 NPerBlock,
98 K0PerBlock,
99 MPerXDL,
100 NPerXDL,
101 K1,
102 MXdlPerWave,
103 NXdlPerWave_,
104 ABlockTransferThreadClusterLengths_K0_M_K1,
105 ABlockTransferThreadClusterArrangeOrder,
106 ABlockTransferSrcAccessOrder,
107 ABlockTransferSrcVectorDim,
108 ABlockTransferSrcScalarPerVector,
109 ABlockTransferDstScalarPerVector_K1,
110 false, // AThreadTransferSrcResetCoordinateAfterRun,
111 ABlockLdsAddExtraM,
112 BBlockTransferThreadClusterLengths_K0_N_K1,
113 BBlockTransferThreadClusterArrangeOrder,
114 BBlockTransferSrcAccessOrder,
115 BBlockTransferSrcVectorDim,
116 BBlockTransferSrcScalarPerVector,
117 BBlockTransferDstScalarPerVector_K1,
118 false, // BThreadTransferSrcResetCoordinateAfterRun,
119 BBlockLdsAddExtraN,
120 CShuffleMRepeatPerShuffle,
121 CShuffleNRepeatPerShuffle,
122 CBlockTransferScalarPerVector_NWaveNPerXDL,
123 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
126
// Host-side argument type shared by both wave sizes: it is always the wave64
// gridwise GEMM's Argument. IsSupportedArgument reinterpret_casts it to
// GridwiseGemm32::Argument on the wave32 path, so the two Argument layouts
// must stay compatible — NOTE(review): nothing here enforces that.
127 using Argument = typename GridwiseGemm64::Argument;
128
129 // Invoker
130 struct Invoker : public BaseInvoker
131 {
// Logs a kernel argument by delegating to its own Print() member.
// RunImp calls this when stream_config.log_level_ > 0.
template <class ArgT>
void Print(const ArgT& karg) { karg.Print(); }
137
/// Validates `karg` for the given GridwiseGemm instantiation, prepares the
/// output / workspace buffers according to the tile map's ReductionStrategy,
/// then launches and times the stream-K kernel.
/// @param karg           kernel argument for this GridwiseGemm.
/// @param stream_config  stream and timing/logging configuration.
/// @return average time reported by the launch helper; stays 0 if neither
///         ReductionStrategy branch below is taken.
/// @throws std::runtime_error if GridwiseGemm::CheckValidity(karg) is false.
138 template <typename GridwiseGemm>
139 float RunImp(const typename GridwiseGemm::Argument& karg,
140 const StreamConfig& stream_config = StreamConfig{})
141 {
142 if(stream_config.log_level_ > 0)
143 {
144 Print(karg);
145 }
146 if(!GridwiseGemm::CheckValidity(karg))
147 {
// NOTE(review): the message names ..._xdlops_v2r4r2, but the gridwise GEMM
// used by this op is GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk.
148 throw std::runtime_error(
149 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
150 "setting");
151 }
152
153 dim3 grid_dims = karg.block_mapping.get_grid_dims();
154
155 float ave_time = 0;
156
158
159 // TODO: remove clear buffer for streamk kernels
160 if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
162 {
// First strategy branch; the compared enumerator is elided from this
// rendering (presumably StreamKReductionStrategy::Atomic). C is zeroed on
// the stream before the timed launch.
// NOTE(review): `karg.M * karg.N` is evaluated in index_t (int32) before it
// is widened by sizeof(CDataType), so it can overflow for large M*N; it also
// assumes a dense M x N buffer (ignores StrideC) — confirm against CLayout.
// NOTE(review): wrapping in hipGetErrorString discards the hipMemsetAsync
// status; hip_check_error would surface failures.
163 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
164 0,
165 karg.M * karg.N * sizeof(CDataType),
166 stream_config.stream_id_));
167 ave_time = launch_and_time_kernel(stream_config,
168 kernel,
169 grid_dims,
170 dim3(BlockSize),
171 0,
172 karg.p_a_grid,
173 karg.p_b_grid,
174 karg.p_c_grid,
175 karg.p_workspace_,
176 karg.M,
177 karg.N,
178 karg.K,
179 karg.StrideA,
180 karg.StrideB,
181 karg.StrideC,
182 karg.block_mapping);
183 }
184 else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
186 {
// Second strategy branch (enumerator elided; presumably Reduction). The
// semaphore region starts right after the accumulator region inside
// p_workspace_, and `preprocess` zeroes it; it is handed to
// launch_and_time_kernel_with_preprocess (presumably run before each timed
// launch — confirm in kernel_launch.hpp).
// NOTE(review): the hipMemsetAsync status is again discarded via
// hipGetErrorString.
187 char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
188 karg.block_mapping.get_workspace_size_for_acc(
189 sizeof(typename GridwiseGemm::FloatAcc));
190 auto preprocess = [&]() {
191 hipGetErrorString(
192 hipMemsetAsync(workspace_semaphore,
193 0,
194 karg.block_mapping.get_workspace_size_for_semaphore(),
195 stream_config.stream_id_));
196 };
197
198 ave_time = launch_and_time_kernel_with_preprocess(stream_config,
199 preprocess,
200 kernel,
201 grid_dims,
202 dim3(BlockSize),
203 0,
204 karg.p_a_grid,
205 karg.p_b_grid,
206 karg.p_c_grid,
207 karg.p_workspace_,
208 karg.M,
209 karg.N,
210 karg.K,
211 karg.StrideA,
212 karg.StrideB,
213 karg.StrideC,
214 karg.block_mapping);
215 }
216
217 return ave_time;
218 }
219
221
222 // polymorphic
223 float Run(const BaseArgument* p_arg,
224 const StreamConfig& stream_config = StreamConfig{}) override
225 {
226 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
227 }
228 };
229
/// Reports how many bytes of workspace the caller must allocate for `pArg`.
/// Chooses the wave64 or wave32 gridwise GEMM from get_warp_size(). It returns
/// block_mapping.get_workspace_size(sizeof(FloatAcc)) only when that GEMM's
/// tile map uses the strategy named in the (elided) comparison below,
/// presumably StreamKReductionStrategy::Reduction. Otherwise it returns 0.
230 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
231 {
232 const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
// NOTE(review): p_arg is dereferenced below without a null check;
// dynamic_cast yields nullptr when pArg is not this op's Argument type.
233 if(get_warp_size() == 64)
234 {
235 if constexpr(GridwiseGemm64::Block2CTileMap::ReductionStrategy ==
237 {
238 return p_arg->block_mapping.get_workspace_size(
239 sizeof(typename GridwiseGemm64::FloatAcc));
240 }
241 }
242 else
243 {
244 if constexpr(GridwiseGemm32::Block2CTileMap::ReductionStrategy ==
246 {
247 return p_arg->block_mapping.get_workspace_size(
248 sizeof(typename GridwiseGemm32::FloatAcc));
249 }
250 }
251 return 0;
252 }
253
255 void* p_workspace,
256 const StreamConfig& = StreamConfig{}) const override
257 {
258 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
259
260 pArg_->p_workspace_ = p_workspace;
261 }
262
/// Compile-time sanity hook for this instance's template parameters.
/// Currently a placeholder that accepts every configuration.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
268
269 static bool IsSupportedArgument(const Argument& karg)
270 {
272 {
273 return false;
274 }
275 if(get_warp_size() == 64)
276 {
277 if constexpr(NXdlPerWave64 > 0)
278 {
280 }
281 }
282 else
283 {
284 if constexpr(NXdlPerWave32 > 0)
285 {
287 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
288 }
289 }
290 return false;
291 }
292
293 // polymorphic
294 bool IsSupportedArgument(const BaseArgument* p_arg) override
295 {
296 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
297 }
298
299 static auto MakeArgument(const ADataType* p_a,
300 const BDataType* p_b,
301 CDataType* p_c,
302 index_t M,
303 index_t N,
304 index_t K,
305 index_t StrideA,
306 index_t StrideB,
307 index_t StrideC,
308 AElementwiseOperation,
309 BElementwiseOperation,
310 CElementwiseOperation,
311 uint32_t NumSKBlocks = 0xffffffff)
312 {
313 int num_cu;
314 hipError_t rtn;
315 int occupancy = [&]() {
316 int occupancy_ = 0;
317 if(get_warp_size() == 64)
318 {
319 if constexpr(NXdlPerWave64 > 0)
320 {
322 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
323 &occupancy_,
324 kernel,
325 BlockSize,
327 hip_check_error(rtn);
328 }
329 }
330 else
331 {
332 if constexpr(NXdlPerWave32 > 0)
333 {
335 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
336 &occupancy_,
337 kernel,
338 BlockSize,
340 hip_check_error(rtn);
341 }
342 }
343 return occupancy_;
344 }();
345
346 hipDeviceProp_t dev_prop;
347 hipDevice_t dev;
348 rtn = hipGetDevice(&dev);
349 hip_check_error(rtn);
350 rtn = hipGetDeviceProperties(&dev_prop, dev);
351 hip_check_error(rtn);
352 num_cu = dev_prop.multiProcessorCount;
353
354 return Argument{p_a,
355 p_b,
356 p_c,
357 M,
358 N,
359 K,
360 StrideA,
361 StrideB,
362 StrideC,
363 static_cast<uint32_t>(num_cu),
364 static_cast<uint32_t>(occupancy),
365 NumSKBlocks};
366 }
367
368 static auto MakeInvoker() { return Invoker{}; }
369
370 // polymorphic
371 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
372 const void* p_b,
373 void* p_c,
374 index_t M,
375 index_t N,
376 index_t K,
377 index_t StrideA,
378 index_t StrideB,
379 index_t StrideC,
380 AElementwiseOperation,
381 BElementwiseOperation,
382 CElementwiseOperation,
383 index_t NumSKBlocks = 0) override
384 {
385 int num_cu;
386 hipError_t rtn;
387
388 int occupancy = [&]() {
389 int occupancy_ = 0;
390 if(get_warp_size() == 64)
391 {
392 if constexpr(NXdlPerWave64 > 0)
393 {
395 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
396 &occupancy_,
397 kernel,
398 BlockSize,
400 hip_check_error(rtn);
401 }
402 }
403 else
404 {
405 if constexpr(NXdlPerWave32 > 0)
406 {
408 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
409 &occupancy_,
410 kernel,
411 BlockSize,
413 hip_check_error(rtn);
414 }
415 }
416 return occupancy_;
417 }();
418
419 hipDeviceProp_t dev_prop;
420 hipDevice_t dev;
421 rtn = hipGetDevice(&dev);
422 hip_check_error(rtn);
423 rtn = hipGetDeviceProperties(&dev_prop, dev);
424 hip_check_error(rtn);
425 num_cu = dev_prop.multiProcessorCount;
426
427 return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
428 reinterpret_cast<const BDataType*>(p_b),
429 reinterpret_cast<CDataType*>(p_c),
430 M,
431 N,
432 K,
433 StrideA,
434 StrideB,
435 StrideC,
436 static_cast<uint32_t>(num_cu),
437 static_cast<uint32_t>(occupancy),
438 static_cast<uint32_t>(NumSKBlocks));
439 }
440
441 // polymorphic
442 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
443 {
444 return std::make_unique<Invoker>(Invoker{});
445 }
446
447 // polymorphic
448 std::string GetTypeString() const override
449 {
452 }
453};
454
455} // namespace device
456} // namespace tensor_operation
457} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB *p_a_grid, const typename GridwiseGemm::FloatAB *p_b_grid, typename GridwiseGemm::FloatC *p_c_grid, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping)
Definition gridwise_gemm_xdlops_streamk.hpp:28
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
unsigned int uint32_t
Definition stdint.h:126
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:1022
Definition gridwise_gemm_xdlops_streamk.hpp:115
Definition device_base.hpp:197
Definition device_gemm_streamk.hpp:25
Definition device_gemm_xdl_streamk.hpp:131
void Print(const Argument_ &karg)
Definition device_gemm_xdl_streamk.hpp:133
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_streamk.hpp:223
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_streamk.hpp:139
Definition device_gemm_xdl_streamk.hpp:70
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_streamk.hpp:263
static constexpr auto I3
Definition device_gemm_xdl_streamk.hpp:78
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, uint32_t NumSKBlocks=0xffffffff)
Definition device_gemm_xdl_streamk.hpp:299
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_streamk.hpp:294
static auto MakeInvoker()
Definition device_gemm_xdl_streamk.hpp:368
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_streamk.hpp:442
std::string GetTypeString() const override
Definition device_gemm_xdl_streamk.hpp:448
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_streamk.hpp:124
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t NumSKBlocks=0) override
Definition device_gemm_xdl_streamk.hpp:371
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_streamk.hpp:125
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock > GridwiseGemmBase
Definition device_gemm_xdl_streamk.hpp:81
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_streamk.hpp:127
static constexpr auto I2
Definition device_gemm_xdl_streamk.hpp:77
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_streamk.hpp:269
static constexpr auto I0
Definition device_gemm_xdl_streamk.hpp:75
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_streamk.hpp:72
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_xdl_streamk.hpp:254
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_streamk.hpp:73
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_xdl_streamk.hpp:230
static constexpr auto I1
Definition device_gemm_xdl_streamk.hpp:76