gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp Source File

gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp Source File
gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
21// GEMM:
22// input : A[M, K]
23// input : B[N, K]
24// input : D0[M, N], D1[M, N], ...
25// output : E[M, N]
26// C = a_op(A) * b_op(B)
27// E = cde_op(C, D0, D1, ...)
28// Assume:
29// D0, D1, ... and E have the same layout
30template <typename ABDataType, // FIXME: don't assume A/B have same datatype
31 typename AccDataType,
32 typename CShuffleDataType,
33 typename DsDataType,
34 typename EDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation,
38 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
39 typename AGridDesc_M_K,
40 typename BGridDesc_N_K,
41 typename DsGridDesc_M_N,
42 typename EGridDesc_M_N,
43 index_t NumGemmKPrefetchStage,
44 index_t BlockSize,
45 index_t MPerBlock,
46 index_t NPerBlock,
47 index_t KPerBlock,
48 index_t AK1Value,
49 index_t BK1Value,
50 index_t MPerXdl,
51 index_t NPerXdl,
52 index_t MXdlPerWave,
53 index_t NXdlPerWave,
54 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55 typename ABlockTransferThreadClusterArrangeOrder,
56 typename ABlockTransferSrcAccessOrder,
57 index_t ABlockTransferSrcVectorDim,
58 index_t ABlockTransferSrcScalarPerVector,
59 index_t ABlockTransferDstScalarPerVector_AK1,
60 bool AThreadTransferSrcResetCoordinateAfterRun,
61 index_t ABlockLdsExtraM,
62 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
63 typename BBlockTransferThreadClusterArrangeOrder,
64 typename BBlockTransferSrcAccessOrder,
65 index_t BBlockTransferSrcVectorDim,
66 index_t BBlockTransferSrcScalarPerVector,
67 index_t BBlockTransferDstScalarPerVector_BK1,
68 bool BThreadTransferSrcResetCoordinateAfterRun,
69 index_t BBlockLdsExtraN,
70 index_t CShuffleMXdlPerWavePerShuffle,
71 index_t CShuffleNXdlPerWavePerShuffle,
72 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
73 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
74 LoopScheduler LoopSched>
76{
77 static constexpr index_t NumDTensor = DsDataType::Size();
78
79 static constexpr auto I0 = Number<0>{};
80 static constexpr auto I1 = Number<1>{};
81 static constexpr auto I2 = Number<2>{};
82 static constexpr auto I3 = Number<3>{};
83 static constexpr auto I4 = Number<4>{};
84 static constexpr auto I5 = Number<5>{};
85 static constexpr auto I6 = Number<6>{};
86 static constexpr auto I7 = Number<7>{};
87
88 // K1 should be Number<...>
89 static constexpr auto AK1 = Number<AK1Value>{};
90 static constexpr auto BK1 = Number<BK1Value>{};
91 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
92 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
93
95
97
98 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
99 {
100 // A matrix in LDS memory, src of blockwise copy
104 }
105
106 __host__ __device__ static constexpr auto GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1()
107 {
108 // A matrix in LDS memory, dst of blockwise copy
113 AK1,
114 I1));
115 }
116
117 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
118 {
119 // B matrix in LDS memory, src of blockwise copy
123 }
124
125 __host__ __device__ static constexpr auto GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1()
126 {
127 // B matrix in LDS memory, dst of blockwise copy
132 BK1,
133 I1));
134 }
135
136 __host__ __device__ static constexpr auto
138 {
139 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
140 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
141
142 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
146 I1,
148
149 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
150 }
151
152 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
153 static constexpr auto MakeDsGridPointer()
154 {
155 return generate_tuple(
156 [&](auto i) {
157 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
158
159 return static_cast<const DDataType*>(nullptr);
160 },
162 }
163
164 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
165 {
    // Returns the larger of two byte counts:
    //   (aligned A tile elements + aligned B tile elements) * sizeof(ABDataType)
    //   C-shuffle tile elements * sizeof(CShuffleDataType)
    // A max (not a sum) is used; Run0 casts the same p_shared pointer both to
    // ABDataType* and to CShuffleDataType*, so the two uses presumably share
    // one allocation -- confirm against the kernel launch code.
166 // LDS allocation for A and B: be careful of alignment
167 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
168 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
169
170 // lds max alignment
    // Expressed in elements (not bytes): it is applied to element counts below.
171 constexpr auto max_lds_align = math::lcm(AK1, BK1);
172
    // Round each tile's element count up to a multiple of max_lds_align.
173 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
174 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
175
176 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
177 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
178
179 // LDS allocation for C shuffle in LDS
    // NOTE(review): the initializer of this descriptor (listing line 181) is not
    // present in this listing; presumably the C-shuffle block descriptor getter
    // defined earlier in the struct -- confirm against the real header.
180 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
182
183 constexpr auto c_block_size =
184 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
185
    // The C-shuffle size is not rounded to max_lds_align, unlike A and B above.
186 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
187 sizeof(ABDataType),
188 c_block_size * sizeof(CShuffleDataType));
189 }
190
191 // A desc for source in blockwise copy
    // Builds a descriptor for A from the [M, K] descriptor and the split-K factor.
    // Only the K-padding arithmetic is visible here; the transform arguments
    // (listing lines 206-209 and 211-214) are missing from this listing.
    // NOTE(review): split_k is used as a divisor factor with no guard; a value of 0
    // would make KPerBlock * split_k zero -- confirm callers always pass >= 1.
192 __host__ __device__ static constexpr auto
193 MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k,
194 const int split_k)
195 {
196 const auto MRaw = a_grid_desc_m_k.GetLength(I0);
197 const auto KRaw = a_grid_desc_m_k.GetLength(I1);
198
    // AK0 * AK1 is the K extent given to each of the split_k parts: KRaw is divided
    // by (KPerBlock * split_k), then scaled back by KPerBlock, so each part is a
    // whole number of KPerBlock steps (assuming integer_divide_ceil rounds up as
    // its name says). The division by AK1 is exact only if KPerBlock is a multiple
    // of AK1 -- presumably guaranteed by the tuning parameters, confirm.
    // Example: KRaw = 100, KPerBlock = 32, split_k = 2, AK1 = 8
    //   -> AK0 = (2 * 32) / 8 = 8, K = 2 * 8 * 8 = 128, KPad = 28.
199 const index_t AK0 =
200 (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / AK1;
201 const index_t K = split_k * AK0 * AK1;
    // Number of K elements added beyond the raw extent.
202 const auto KPad = K - KRaw;
203
204 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
205 a_grid_desc_m_k,
210 a_grid_desc_m_kpad,
215 }
216
217 // B desc for source in blockwise copy
    // Builds a descriptor for B from the [N, K] descriptor and the split-K factor.
    // Uses the same K-padding arithmetic as the A variant, with BK1 in place of AK1.
    // The transform arguments (listing lines 232-234, 236 and 238-241) are missing
    // from this listing.
    // NOTE(review): split_k is used as a divisor factor with no guard; a value of 0
    // would make KPerBlock * split_k zero -- confirm callers always pass >= 1.
218 __host__ __device__ static constexpr auto
219 MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k,
220 const int split_k)
221 {
222 const auto NRaw = b_grid_desc_n_k.GetLength(I0);
223 const auto KRaw = b_grid_desc_n_k.GetLength(I1);
224
    // BK0 * BK1 is the K extent given to each of the split_k parts, rounded up to a
    // whole number of KPerBlock steps (assuming integer_divide_ceil rounds up).
    // The division by BK1 is exact only if KPerBlock is a multiple of BK1 --
    // presumably guaranteed by the tuning parameters, confirm.
225 const index_t BK0 =
226 (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / BK1;
227 const index_t K = split_k * BK0 * BK1;
    // Number of K elements added beyond the raw extent.
228 const auto KPad = K - KRaw;
229
230 const auto b_grid_desc_n_kpad = transform_tensor_descriptor(
231 b_grid_desc_n_k,
235
237 b_grid_desc_n_kpad,
242 }
243
244 // E desc for destination in blockwise copy
    // Derives a descriptor from the [M, N] descriptor of E using the block counts
    // MBlock and NBlock computed below. The transform arguments (listing lines
    // 257-260) are missing from this listing, so the exact resulting layout cannot
    // be confirmed from here; the name suggests [MBlock, MPerBlock, NBlock, NPerBlock].
245 template <typename EGridDescriptor_M_N>
246 __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
247 const EGridDescriptor_M_N& e_grid_desc_m_n)
248 {
249 const auto M = e_grid_desc_m_n.GetLength(I0);
250 const auto N = e_grid_desc_m_n.GetLength(I1);
251
    // Truncating division: any remainder is dropped. CheckValidity rejects problems
    // where M % MPerBlock != 0 or N % NPerBlock != 0, so the division is exact for
    // arguments that passed that check.
252 const auto MBlock = M / MPerBlock;
253 const auto NBlock = N / NPerBlock;
254
255 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
256 e_grid_desc_m_n,
261
262 return e_grid_desc_mblock_mperblock_nblock_nperblock;
263 }
264
265 // Ds desc for source in blockwise copy
266 template <typename DsGridDescriptor_M_N>
267 __host__ __device__ static constexpr auto
269 const DsGridDescriptor_M_N& ds_grid_desc_m_n)
270 {
271 return generate_tuple(
272 [&](auto i) {
274 },
276 }
277
278 // return block_id to E matrix tile idx (m0, n0) mapping
279 __host__ __device__ static constexpr auto
280 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, const int split_k)
281 {
283 e_grid_desc_m_n, 8, split_k);
284 }
285
287
288 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    // CheckValidity: returns true only if every check below passes; the first
    // failing check returns false. Checks, in order:
    //   1. (compile time) MPerBlock / NPerBlock divisible by the per-wave XDL tile
    //   2. A and B agree on K and on the length of dimension 0
    //   3. E and every D tensor are M x N
    //   4. M, N, K are multiples of MPerBlock, NPerBlock, KPerBlock
    //   5. the GEMM pipeline supports the resulting K loop count
    //   6. the block-to-tile map accepts the E descriptor
    //   7. A, B and E each occupy at most 2^31 bytes
289 template <typename AGridDesc_AKB_AK0_M_AK1,
290 typename BGridDesc_BKB_BK0_N_BK1,
291 typename Block2ETileMap>
292 __host__ __device__ static constexpr bool
293 CheckValidity(const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1,
294 const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1,
295 const DsGridDesc_M_N& ds_grid_desc_m_n,
296 const EGridDesc_M_N& e_grid_desc_m_n,
297 const Block2ETileMap& block_2_etile_map)
298 {
299 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
300 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
301 "Invalid tuning param!");
302
    // The 4-D descriptors are read as: dim 2 = M (or N), dims 1 and 3 multiplied = K.
    // Dimension 0 is not part of K here; Run0/Run1 index it with k_batch_id, so K is
    // the extent seen by one K split.
303 const auto M = a_grid_desc_akb_ak0_m_ak1.GetLength(I2);
304 const auto N = b_grid_desc_bkb_bk0_n_bk1.GetLength(I2);
305 const auto K =
306 a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3);
307
    // A and B must describe the same K extent ...
308 if(K != b_grid_desc_bkb_bk0_n_bk1.GetLength(I1) * b_grid_desc_bkb_bk0_n_bk1.GetLength(I3))
309 {
310 return false;
311 }
    // ... and the same length along dimension 0.
312 if(a_grid_desc_akb_ak0_m_ak1.GetLength(I0) != b_grid_desc_bkb_bk0_n_bk1.GetLength(I0))
313 {
314 return false;
315 }
316
317 // check consistency of desc
318 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
319 {
320 return false;
321 }
322
323 bool valid = true;
324
    // Every D descriptor must have the same M x N lengths as E. With NumDTensor == 0
    // the loop body never runs and valid stays true.
325 static_for<0, NumDTensor, 1>{}([&](auto i) {
326 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
327 N == ds_grid_desc_m_n[i].GetLength(I1));
328 });
329
330 if(!valid)
331 {
332 return false;
333 }
334
335 // check tile size
336 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
337 {
338 return false;
339 }
340
341 // check gridwise gemm pipeline
    // Exact division: K % KPerBlock == 0 was established just above.
    // NOTE(review): GridwiseGemmPipe is not defined in the visible part of this
    // listing (listing line 286 is missing); its IsSupported contract is not shown.
342 const auto num_k_loop = K / KPerBlock;
343
344 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
345 {
346 return false;
347 }
348
349 // check block-to-E-tile
350 if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
351 {
352 return false;
353 }
354
355 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
356 // check tensor size: cannot be larger than 2GB each
    // NOTE(review): only A, B and E are size-checked; the D tensors in
    // ds_grid_desc_m_n are not included in this check -- confirm that is intended.
357 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
358
359 if(!(a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
360 b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
361 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
362 {
363 return false;
364 }
365
366 return true;
367 }
368
369 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
370 {
    // Converts a K extent into a count of KPerBlock-sized steps and returns what
    // GridwiseGemmPipe::CalculateHasMainLoop reports for that count.
    // The division truncates, so K is expected to be a multiple of KPerBlock
    // (CheckValidity enforces this for the K it derives from the A descriptor).
    // NOTE(review): presumably K here is the per-split extent, matching the K that
    // CheckValidity and Run0/Run1 compute from descriptor dims 1 and 3 -- confirm
    // against callers.
371 const index_t num_loop = K / KPerBlock;
372
373 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
374 }
375
377 remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
379 remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
382 EGridDesc_M_N{}))>;
385 DsGridDesc_M_N{}))>;
386
388 remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1))>;
389
390 using DsGridPointer = decltype(MakeDsGridPointer());
391
392 template <bool HasMainKBlockLoop,
393 typename AGridDesc_AKB_AK0_M_AK1,
394 typename BGridDesc_BKB_BK0_N_BK1,
395 typename Block2ETileMap>
396 __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
397 const ABDataType* __restrict__ p_b_grid,
398 DsGridPointer p_ds_grid,
399 EDataType* __restrict__ p_e_grid,
400 void* __restrict__ p_shared,
401 const AElementwiseOperation& a_element_op,
402 const BElementwiseOperation& b_element_op,
403 const CDEElementwiseOperation& cde_element_op,
404 const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1,
405 const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1,
407 ds_grid_desc_mblock_mperblock_nblock_nperblock,
409 e_grid_desc_mblock_mperblock_nblock_nperblock,
410 const Block2ETileMap& block_2_etile_map)
411 {
    // Device entry point: maps this block's 1-D id through block_2_etile_map and
    // picks one of two code paths from component 0 of the result. Run0 and Run1
    // read that same component as k_batch_id.
    //   component 0 == 0 : the callee receives p_ds_grid and cde_element_op
    //   otherwise        : the callee receives neither
    // NOTE(review): the callee names (listing lines 417 and 433) and the types of
    // the ds/e descriptor parameters (listing lines 406 and 408) are missing from
    // this listing. The argument lists match Run0 and Run1 respectively --
    // presumably those are the callees, and presumably the split is so that the D
    // tensors are combined by only one K split. Confirm against the real header.
412 const auto block_work_idx =
413 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
414
415 if(block_work_idx[Number<0>{}] == 0)
416 {
418 p_b_grid,
419 p_ds_grid,
420 p_e_grid,
421 p_shared,
422 a_element_op,
423 b_element_op,
424 cde_element_op,
425 a_grid_desc_akb_ak0_m_ak1,
426 b_grid_desc_bkb_bk0_n_bk1,
427 ds_grid_desc_mblock_mperblock_nblock_nperblock,
428 e_grid_desc_mblock_mperblock_nblock_nperblock,
429 block_2_etile_map);
430 }
431 else
432 {
    // The Ds descriptor is still forwarded on this path even though p_ds_grid is not.
434 p_b_grid,
435 p_e_grid,
436 p_shared,
437 a_element_op,
438 b_element_op,
439 a_grid_desc_akb_ak0_m_ak1,
440 b_grid_desc_bkb_bk0_n_bk1,
441 ds_grid_desc_mblock_mperblock_nblock_nperblock,
442 e_grid_desc_mblock_mperblock_nblock_nperblock,
443 block_2_etile_map);
444 }
445 }
446 template <bool HasMainKBlockLoop,
447 typename AGridDesc_AKB_AK0_M_AK1,
448 typename BGridDesc_BKB_BK0_N_BK1,
449 typename Block2ETileMap>
450 __device__ static void Run0(const ABDataType* __restrict__ p_a_grid,
451 const ABDataType* __restrict__ p_b_grid,
452 DsGridPointer p_ds_grid,
453 EDataType* __restrict__ p_e_grid,
454 void* __restrict__ p_shared,
455 const AElementwiseOperation& a_element_op,
456 const BElementwiseOperation& b_element_op,
457 const CDEElementwiseOperation& cde_element_op,
458 const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1,
459 const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1,
461 ds_grid_desc_mblock_mperblock_nblock_nperblock,
463 e_grid_desc_mblock_mperblock_nblock_nperblock,
464 const Block2ETileMap& block_2_etile_map)
465 {
466 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
467 p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize());
468
469 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
470 p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize());
471
472 const auto ds_grid_buf = generate_tuple(
473 [&](auto i) {
475 p_ds_grid[i],
476 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
477 },
479
481 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
482
483 // divide block work by [M, N]
484 const auto block_work_idx =
485 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
486
487 if(!block_2_etile_map.ValidCTileIndex(
488 make_tuple(block_work_idx[I1], block_work_idx[I2]),
489 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
490 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
491 {
492 return;
493 }
494
495 // HACK: this force m/n_block_data_idx_on_grid into SGPR
496 const index_t k_batch_id = block_work_idx[I0];
497
498 const index_t m_block_data_idx_on_grid =
499 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
500
501 const index_t n_block_data_idx_on_grid =
502 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
503
504 // lds max alignment
505 constexpr auto max_lds_align = math::lcm(AK1, BK1);
506
507 // A matrix in LDS memory, dst of blockwise copy
508 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
509 constexpr auto a_block_desc_akb_ak0_m_ak1 =
511
512 // B matrix in LDS memory, dst of blockwise copy
513 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
514 constexpr auto b_block_desc_bkb_bk0_n_bk1 =
516
517 // A matrix blockwise copy
518 auto a_blockwise_copy =
520 AElementwiseOperation,
524 ABlockTransferThreadClusterLengths_AK0_M_AK1,
525 ABlockTransferThreadClusterArrangeOrder,
526 ABDataType,
527 ABDataType,
528 decltype(a_grid_desc_akb_ak0_m_ak1),
529 decltype(a_block_desc_akb_ak0_m_ak1),
530 ABlockTransferSrcAccessOrder,
532 ABlockTransferSrcVectorDim,
533 3,
534 ABlockTransferSrcScalarPerVector,
535 ABlockTransferDstScalarPerVector_AK1,
536 1,
537 1,
538 AThreadTransferSrcResetCoordinateAfterRun,
539 true,
540 NumGemmKPrefetchStage>(
541 a_grid_desc_akb_ak0_m_ak1,
542 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
543 a_element_op,
544 a_block_desc_akb_ak0_m_ak1,
545 make_multi_index(0, 0, 0, 0),
547
548 // B matrix blockwise copy
549 auto b_blockwise_copy =
551 BElementwiseOperation,
555 BBlockTransferThreadClusterLengths_BK0_N_BK1,
556 BBlockTransferThreadClusterArrangeOrder,
557 ABDataType,
558 ABDataType,
559 decltype(b_grid_desc_bkb_bk0_n_bk1),
560 decltype(b_block_desc_bkb_bk0_n_bk1),
561 BBlockTransferSrcAccessOrder,
563 BBlockTransferSrcVectorDim,
564 3,
565 BBlockTransferSrcScalarPerVector,
566 BBlockTransferDstScalarPerVector_BK1,
567 1,
568 1,
569 BThreadTransferSrcResetCoordinateAfterRun,
570 true,
571 NumGemmKPrefetchStage>(
572 b_grid_desc_bkb_bk0_n_bk1,
573 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
574 b_element_op,
575 b_block_desc_bkb_bk0_n_bk1,
576 make_multi_index(0, 0, 0, 0),
578
579 // GEMM definition
580 // c_mtx += transpose(a_mtx) * b_mtx
581 // a_mtx[K0PerBlock, MPerBlock] is in LDS
582 // b_mtx[K0PerBlock, NPerBlock] is in LDS
583 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
584 // register
585 // sanity check
586 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
587 constexpr bool is_single_rate_mfma =
589 lcm_AK1_BK1 <= 4) ||
590 (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
592 lcm_AK1_BK1 < 32))
593 ? true
594 : false;
595 constexpr auto is_scale_mfma = false;
596 constexpr index_t KPack = math::max(lcm_AK1_BK1,
597 MfmaSelector<ABDataType,
598 MPerXdl,
599 NPerXdl,
600 ABDataType,
601 is_single_rate_mfma,
602 is_scale_mfma>::selected_mfma.k_per_blk);
603
605 BlockSize,
606 ABDataType,
607 ABDataType,
608 AccDataType,
609 decltype(a_block_desc_ak0_m_ak1),
610 decltype(b_block_desc_bk0_n_bk1),
611 MPerXdl,
612 NPerXdl,
613 MXdlPerWave,
614 NXdlPerWave,
615 KPack,
616 LoopSched>();
617
618 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
619
620 // LDS allocation for A and B: be careful of alignment
621 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
622 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
623
625 static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
626
628 static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
629 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
630
631 constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
632 constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
633
634 // gridwise GEMM pipeline
635 const auto gridwise_gemm_pipeline =
637
638 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
639 (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) /
640 KPerBlock);
641
642 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_akb_ak0_m_ak1,
643 a_block_desc_akb_ak0_m_ak1,
644 a_blockwise_copy,
645 a_grid_buf,
646 a_block_buf,
647 a_block_slice_copy_step,
648 b_grid_desc_bkb_bk0_n_bk1,
649 b_block_desc_bkb_bk0_n_bk1,
650 b_blockwise_copy,
651 b_grid_buf,
652 b_block_buf,
653 b_block_slice_copy_step,
654 blockwise_gemm,
655 c_thread_buf,
656 num_k_block_main_loop);
657
658 // shuffle C and write out
659 {
660 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
661 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
662 "wrong!");
663
664 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
665 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
666
667 // TODO: hacky, fix it!
668 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
669 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
670
671 // TODO: hacky, fix it!
672 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
673 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
674 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
675
676 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
677 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
678 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
679 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
680 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
681 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
682 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
683 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
684
685 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
687
688 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
689 static_cast<CShuffleDataType*>(p_shared),
690 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
691
692 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
693 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
697 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
698 M1, // M1 = MWave
699 M2, // M2 * M3 * M4 = MPerXdl
700 M3,
701 M4)),
704 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
705 N1, // N1 = NWave
706 N2))), // N2 = NPerXdl
710
711 // calculate origin of thread output tensor on global memory
712 // blockwise GEMM c matrix starting index
713 const auto c_thread_mtx_on_block =
714 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
715
716 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
717 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
718
719 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
721 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
724
725 const auto m_thread_data_on_block_idx =
726 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
727 make_multi_index(m_thread_data_on_block));
728
729 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
734
735 const auto n_thread_data_on_block_idx =
736 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
737 make_multi_index(n_thread_data_on_block));
738
739 // shuffle: threadwise copy C from VGPR to LDS
740 auto c_thread_copy_vgpr_to_lds =
742 CShuffleDataType,
743 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
744 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
746 Sequence<CShuffleMXdlPerWavePerShuffle,
747 CShuffleNXdlPerWavePerShuffle,
748 I1,
749 I1,
750 M2,
751 I1,
752 M4,
753 I1>,
755 7,
756 1,
758 1,
759 true>{
760 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
762 0,
763 m_thread_data_on_block_idx[I1],
764 n_thread_data_on_block_idx[I1],
765 m_thread_data_on_block_idx[I2],
766 m_thread_data_on_block_idx[I3],
767 m_thread_data_on_block_idx[I4],
768 n_thread_data_on_block_idx[I2]),
770 {
771 // tuple of reference to C/Ds tensor descriptors
772 const auto c_ds_desc_refs = concat_tuple_of_reference(
773 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
774 generate_tie([&](auto i) -> const auto& // return type should be reference
775 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
777
778 // tuple of reference to C/Ds tensor descriptors
779 const auto c_ds_buf_refs = concat_tuple_of_reference(
780 tie(c_shuffle_block_buf),
781 generate_tie([&](auto i) -> const auto& // return type should be reference
782 { return ds_grid_buf[i]; },
784
785 // tuple of starting index of C/Ds blockwise copy
786 const auto idx_c_ds_block_begin = container_concat(
787 make_tuple(make_multi_index(0, 0, 0, 0)),
789 [&](auto) {
790 return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
791 },
793
794 // blockwise copy C/D/E between LDS and global
795 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
797 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
799 decltype(c_ds_desc_refs),
800 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
801 CDEElementwiseOperation,
802 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
803 // Sequence support
804 // arbitrary type
805 Sequence<1,
806 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
807 1,
808 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
809 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
810 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
811 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
812 3, // index_t VectorDim,
813 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
817 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
818 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
819 {c_ds_desc_refs,
820 idx_c_ds_block_begin,
821 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
822 make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
823 cde_element_op};
824
825 // space filling curve for threadwise C in VGPR before shuffle
826 constexpr auto sfc_c_vgpr =
829 Sequence<CShuffleMXdlPerWavePerShuffle,
830 CShuffleNXdlPerWavePerShuffle,
831 1,
832 1,
833 M2,
834 1,
835 M4,
836 1>>{};
837
838 // space filling curve for shuffled blockwise C/D/E
839 constexpr auto sfc_cde_block =
842 Sequence<1,
843 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
844 1,
845 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
846
847 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
848
849 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
850
851 static_for<0, num_access, 1>{}([&](auto access_id) {
852 // make sure it's safe to write to LDS
854
855 // each thread write its data from VGPR to LDS
856 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
857 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
858 c_thread_buf,
859 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
860 c_shuffle_block_buf);
861
862 // make sure it's safe to read from LDS
864
865 // each block copy its data from LDS to global
866 cde_block_copy_lds_and_global.Run(
867 c_ds_desc_refs,
868 c_ds_buf_refs,
869 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
870 tie(e_grid_buf));
871
872 if constexpr(access_id < num_access - 1)
873 {
874 constexpr auto cde_lds_and_global_step =
875 sfc_cde_block.GetForwardStep(access_id);
876
877 // move on Ds
878 static_for<0, NumDTensor, 1>{}([&](auto i) {
879 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
880 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
881 });
882
883 // move on E
884 cde_block_copy_lds_and_global.MoveDstSliceWindow(
885 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
886 I0,
887 cde_lds_and_global_step);
888 }
889 });
890 }
891 }
892 }
893
894 template <bool HasMainKBlockLoop,
895 typename AGridDesc_AKB_AK0_M_AK1,
896 typename BGridDesc_BKB_BK0_N_BK1,
897 typename Block2ETileMap>
898 __device__ static void Run1(const ABDataType* __restrict__ p_a_grid,
899 const ABDataType* __restrict__ p_b_grid,
900 EDataType* __restrict__ p_e_grid,
901 void* __restrict__ p_shared,
902 const AElementwiseOperation& a_element_op,
903 const BElementwiseOperation& b_element_op,
904 const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1,
905 const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1,
908 e_grid_desc_mblock_mperblock_nblock_nperblock,
909 const Block2ETileMap& block_2_etile_map)
910 {
911 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
912 p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize());
913
914 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
915 p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize());
916
918 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
919
920 // divide block work by [M, N]
921 const auto block_work_idx =
922 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
923
924 if(!block_2_etile_map.ValidCTileIndex(
925 make_tuple(block_work_idx[I1], block_work_idx[I2]),
926 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
927 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
928 {
929 return;
930 }
931
932 // HACK: this force m/n_block_data_idx_on_grid into SGPR
933 const index_t k_batch_id = block_work_idx[I0];
934
935 const index_t m_block_data_idx_on_grid =
936 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
937
938 const index_t n_block_data_idx_on_grid =
939 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
940
941 // lds max alignment
942 constexpr auto max_lds_align = math::lcm(AK1, BK1);
943
944 // A matrix in LDS memory, dst of blockwise copy
945 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
946 constexpr auto a_block_desc_akb_ak0_m_ak1 =
948
949 // B matrix in LDS memory, dst of blockwise copy
950 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
951 constexpr auto b_block_desc_bkb_bk0_n_bk1 =
953
954 // A matrix blockwise copy
955 auto a_blockwise_copy =
957 AElementwiseOperation,
961 ABlockTransferThreadClusterLengths_AK0_M_AK1,
962 ABlockTransferThreadClusterArrangeOrder,
963 ABDataType,
964 ABDataType,
965 decltype(a_grid_desc_akb_ak0_m_ak1),
966 decltype(a_block_desc_akb_ak0_m_ak1),
967 ABlockTransferSrcAccessOrder,
969 ABlockTransferSrcVectorDim,
970 3,
971 ABlockTransferSrcScalarPerVector,
972 ABlockTransferDstScalarPerVector_AK1,
973 1,
974 1,
975 AThreadTransferSrcResetCoordinateAfterRun,
976 true,
977 NumGemmKPrefetchStage>(
978 a_grid_desc_akb_ak0_m_ak1,
979 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
980 a_element_op,
981 a_block_desc_akb_ak0_m_ak1,
982 make_multi_index(0, 0, 0, 0),
984
985 // B matrix blockwise copy
986 auto b_blockwise_copy =
988 BElementwiseOperation,
992 BBlockTransferThreadClusterLengths_BK0_N_BK1,
993 BBlockTransferThreadClusterArrangeOrder,
994 ABDataType,
995 ABDataType,
996 decltype(b_grid_desc_bkb_bk0_n_bk1),
997 decltype(b_block_desc_bkb_bk0_n_bk1),
998 BBlockTransferSrcAccessOrder,
1000 BBlockTransferSrcVectorDim,
1001 3,
1002 BBlockTransferSrcScalarPerVector,
1003 BBlockTransferDstScalarPerVector_BK1,
1004 1,
1005 1,
1006 BThreadTransferSrcResetCoordinateAfterRun,
1007 true,
1008 NumGemmKPrefetchStage>(
1009 b_grid_desc_bkb_bk0_n_bk1,
1010 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
1011 b_element_op,
1012 b_block_desc_bkb_bk0_n_bk1,
1013 make_multi_index(0, 0, 0, 0),
1015
1016 // GEMM definition
1017 // c_mtx += transpose(a_mtx) * b_mtx
1018 // a_mtx[K0PerBlock, MPerBlock] is in LDS
1019 // b_mtx[K0PerBlock, NPerBlock] is in LDS
1020 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
1021 // register
1022 // sanity check
1023 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
1024 constexpr bool is_single_rate_mfma =
1026 lcm_AK1_BK1 <= 4) ||
1027 (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
1029 lcm_AK1_BK1 < 32))
1030 ? true
1031 : false;
1032 constexpr auto is_scale_mfma = false;
1033 constexpr index_t KPack = math::max(lcm_AK1_BK1,
1034 MfmaSelector<ABDataType,
1035 MPerXdl,
1036 NPerXdl,
1037 ABDataType,
1038 is_single_rate_mfma,
1039 is_scale_mfma>::selected_mfma.k_per_blk);
1040
1042 BlockSize,
1043 ABDataType,
1044 ABDataType,
1045 AccDataType,
1046 decltype(a_block_desc_ak0_m_ak1),
1047 decltype(b_block_desc_bk0_n_bk1),
1048 MPerXdl,
1049 NPerXdl,
1050 MXdlPerWave,
1051 NXdlPerWave,
1052 KPack,
1053 LoopSched>();
1054
1055 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
1056
1057 // LDS allocation for A and B: be careful of alignment
1058 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1059 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1060
1062 static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1063
1065 static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
1066 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1067
1068 constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
1069 constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
1070
1071 // gridwise GEMM pipeline
1072 const auto gridwise_gemm_pipeline =
1074
1075 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1076 (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) /
1077 KPerBlock);
1078
1079 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_akb_ak0_m_ak1,
1080 a_block_desc_akb_ak0_m_ak1,
1081 a_blockwise_copy,
1082 a_grid_buf,
1083 a_block_buf,
1084 a_block_slice_copy_step,
1085 b_grid_desc_bkb_bk0_n_bk1,
1086 b_block_desc_bkb_bk0_n_bk1,
1087 b_blockwise_copy,
1088 b_grid_buf,
1089 b_block_buf,
1090 b_block_slice_copy_step,
1091 blockwise_gemm,
1092 c_thread_buf,
1093 num_k_block_main_loop);
1094
1095 // shuffle C and write out
1096 {
1097 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1098 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1099 "wrong!");
1100
1101 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1102 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1103
1104 // TODO: hacky, fix it!
1105 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1106 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1107
1108 // TODO: hacky, fix it!
1109 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1110 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1111 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1112
1113 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1114 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1115 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1116 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1117 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1118 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1119 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1120 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1121
1122 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1124
1125 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1126 static_cast<CShuffleDataType*>(p_shared),
1127 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1128
1129 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1130 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1131 make_tuple(
1134 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1135 M1, // M1 = MWave
1136 M2, // M2 * M3 * M4 = MPerXdl
1137 M3,
1138 M4)),
1141 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1142 N1, // N1 = NWave
1143 N2))), // N2 = NPerXdl
1145 make_tuple(
1147
1148 // calculate origin of thread output tensor on global memory
1149 // blockwise GEMM c matrix starting index
1150 const auto c_thread_mtx_on_block =
1151 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1152
1153 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1154 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1155
1156 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1158 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1161
1162 const auto m_thread_data_on_block_idx =
1163 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1164 make_multi_index(m_thread_data_on_block));
1165
1166 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1171
1172 const auto n_thread_data_on_block_idx =
1173 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1174 make_multi_index(n_thread_data_on_block));
1175
1176 // shuffle: threadwise copy C from VGPR to LDS
1177 auto c_thread_copy_vgpr_to_lds =
1179 CShuffleDataType,
1180 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1181 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1183 Sequence<CShuffleMXdlPerWavePerShuffle,
1184 CShuffleNXdlPerWavePerShuffle,
1185 I1,
1186 I1,
1187 M2,
1188 I1,
1189 M4,
1190 I1>,
1192 7,
1193 1,
1195 1,
1196 true>{
1197 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1199 0,
1200 m_thread_data_on_block_idx[I1],
1201 n_thread_data_on_block_idx[I1],
1202 m_thread_data_on_block_idx[I2],
1203 m_thread_data_on_block_idx[I3],
1204 m_thread_data_on_block_idx[I4],
1205 n_thread_data_on_block_idx[I2]),
1207 {
1208 // shuffle: blockwise copy C from LDS to global
1209 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1210 ThisThreadBlock, // ThreadGroup
1211 ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation,
1212 EGlobalMemoryDataOperation, // DstInMemOp,
1213 Sequence<1,
1214 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1215 1,
1216 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1217 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1218 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1219 CShuffleDataType, // typename SrcData,
1220 EDataType, // typename DstData,
1221 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1222 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
1223 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1224 3, // index_t VectorDim,
1225 CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1226 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1227 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1228 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1229 make_multi_index(0, 0, 0, 0),
1230 e_grid_desc_mblock_mperblock_nblock_nperblock,
1231 make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
1233
1234 // space filling curve for threadwise C in VGPR
1235 constexpr auto sfc_c_vgpr =
1238 Sequence<CShuffleMXdlPerWavePerShuffle,
1239 CShuffleNXdlPerWavePerShuffle,
1240 1,
1241 1,
1242 M2,
1243 1,
1244 M4,
1245 1>>{};
1246
1247 // space filling curve for shuffled blockwise C in global mem
1248 constexpr auto sfc_c_global =
1251 Sequence<1,
1252 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1253 1,
1254 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1255
1256 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1257
1258 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1259
1260 static_for<0, num_access, 1>{}([&](auto access_id) {
1261 // make sure it's safe to write to LDS
1263
1264 // each thread writes its data from VGPR to LDS
1265 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1266 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1267 c_thread_buf,
1268 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1269 c_shuffle_block_buf);
1270
1271 // make sure it's safe to read from LDS
1273
1274 // each block copies its data from LDS to global
1275 c_shuffle_block_copy_lds_to_global.Run(
1276 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1277 c_shuffle_block_buf,
1278 e_grid_desc_mblock_mperblock_nblock_nperblock,
1279 e_grid_buf);
1280
1281 if constexpr(access_id < num_access - 1)
1282 {
1283 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1284
1285 // move on C
1286 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1287 e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1288 }
1289 });
1290 }
1291 }
1292 }
1293};
1294
1295} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr auto GridwiseGemmPipeline_v1_Selector()
Definition gridwise_gemm_pipeline_v1.hpp:758
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition block_to_ctile_map.hpp:541
Definition gridwise_gemm_pipeline_v1.hpp:13
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp:76
static __device__ void Run0(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_AKB_AK0_M_AK1 &a_grid_desc_akb_ak0_m_ak1, const BGridDesc_BKB_BK0_N_BK1 &b_grid_desc_bkb_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp:450
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_AKB_AK0_M_AK1 &a_grid_desc_akb_ak0_m_ak1, const BGridDesc_BKB_BK0_N_BK1 &b_grid_desc_bkb_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp:396
static __device__ void Run1(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const AGridDesc_AKB_AK0_M_AK1 &a_grid_desc_akb_ak0_m_ak1, const BGridDesc_BKB_BK0_N_BK1 &b_grid_desc_bkb_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp:898
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340