gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp Source File

gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp Source File

Composable Kernel: gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp Source File
gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
20
21namespace ck {
22
23// GEMM:
24// input : A[M, K]
25// input : B[N, K]
26// input : D0[M, N], D1[M, N], ...
27// output : E[M, N]
28// C = a_op(A) * b_op(B)
29// E = cde_op(C, D0, D1, ...)
30// Assume:
31// D0, D1, ... and E have the same layout
32template <typename ADataType, // FIXME: don't assume A/B have same datatype
33 typename BDataType,
34 typename AccDataType,
35 typename CShuffleDataType,
36 typename DsDataType,
37 typename EDataType,
38 typename ComputeType,
39 typename AElementwiseOperation,
40 typename BElementwiseOperation,
41 typename CDEElementwiseOperation,
42 index_t NumGemmKPrefetchStage,
43 index_t BlockSize,
44 index_t MPerBlock,
45 index_t NPerBlock,
46 index_t KPerBlock,
47 index_t AK1Value,
48 index_t BK1Value,
49 index_t MPerXdl,
50 index_t NPerXdl,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
59 bool AThreadTransferSrcResetCoordinateAfterRun,
60 index_t ABlockLdsExtraM,
61 typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
62 typename BBlockTransferThreadClusterArrangeOrder,
63 typename BBlockTransferSrcAccessOrder,
64 index_t BBlockTransferSrcVectorDim,
65 index_t BBlockTransferSrcScalarPerVector,
66 index_t BBlockTransferDstScalarPerVector_BK1,
67 bool BThreadTransferSrcResetCoordinateAfterRun,
68 index_t BBlockLdsExtraN,
69 index_t CShuffleMXdlPerWavePerShuffle,
70 index_t CShuffleNXdlPerWavePerShuffle,
71 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
73 LoopScheduler LoopSched,
76{
77 static constexpr index_t NumDTensor = DsDataType::Size();
78
80
81 static constexpr auto I0 = Number<0>{};
82 static constexpr auto I1 = Number<1>{};
83 static constexpr auto I2 = Number<2>{};
84 static constexpr auto I3 = Number<3>{};
85 static constexpr auto I4 = Number<4>{};
86 static constexpr auto I5 = Number<5>{};
87 static constexpr auto I6 = Number<6>{};
88 static constexpr auto I7 = Number<7>{};
89
90 // K1 should be Number<...>
91 static constexpr auto AK1 = Number<AK1Value>{};
92 static constexpr auto BK1 = Number<BK1Value>{};
93 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
94 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
95
97
100
101 __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
102 {
103 // A matrix in LDS memory, dst of blockwise copy
108 AK1,
109 I1));
110 }
111
112 __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
113 {
114 // B matrix in LDS memory, dst of blockwise copy
119 BK1,
120 I1));
121 }
122
123 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
124 {
125 // A matrix in LDS memory, dst of blockwise copy
129 }
130
131 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
132 {
133 // B matrix in LDS memory, dst of blockwise copy
137 }
138
139 __host__ __device__ static constexpr auto
141 {
142 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
143 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
144
145 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
149 I1,
151
152 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
153 }
154
155 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
// Builds a tuple of null pointers, one `const DDataType*` per element type taken from
// DsDataType. Within this listing the result is only consumed through
// `decltype(MakeDsGridPointer())` to define the DsGridPointer alias, i.e. it serves as a
// type factory rather than a source of usable pointers.
// NOTE(review): the tuple-length argument to generate_tuple is missing from this listing
// (listing line 164); presumably Number<NumDTensor>{} -- confirm against the real header.
156 static constexpr auto MakeDsGridPointer()
157 {
158 return generate_tuple(
159 [&](auto i) {
 // Element type of the i-th D tensor, with cv/ref qualifiers stripped.
160 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
161
162 return static_cast<const DDataType*>(nullptr);
163 },
165 }
166
// Returns the number of bytes of LDS this kernel asks for.
// The result is the larger of two footprints:
//   (a) the A tile plus the B tile, each rounded up to a multiple of lcm(AK1, BK1) elements
//       and scaled by sizeof(ADataType) / sizeof(BDataType) respectively;
//   (b) the C-shuffle tile scaled by sizeof(CShuffleDataType).
// A max (not a sum) is used; Run() in this file addresses both the A/B tiles and the
// C-shuffle tile starting from the same p_shared base pointer.
// NOTE(review): Run() views the A/B tiles through `ComputeType*` (the B tile begins
// a_block_space_size_aligned ComputeType elements past p_shared), whereas (a) is sized with
// ADataType / BDataType. If sizeof(ComputeType) is larger than either, (a) may under-count
// the LDS actually touched -- confirm against the supported type combinations.
167 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
168 {
169 // LDS allocation for A and B: be careful of alignment
170 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
171 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
172
173 // lds max alignment
 // Alignment is expressed in elements (least common multiple of the two K1 vector widths).
174 constexpr auto max_lds_align = math::lcm(AK1, BK1);
175
176 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
177 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
178
179 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
180 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
181
182 // LDS allocation for C shuffle in LDS
 // NOTE(review): the initializer of the descriptor below (listing line 184) is missing
 // from this listing.
183 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
185
186 constexpr auto c_block_size =
187 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
188
 // Byte footprint: max(A bytes + B bytes, C-shuffle bytes).
189 return math::max(a_block_space_size_aligned * sizeof(ADataType) +
190 b_block_space_size_aligned * sizeof(BDataType),
191 c_block_size * sizeof(CShuffleDataType));
192 }
193
// Pads M up to a multiple of the MPerBlock tile size by delegating to
// math::integer_least_multiple(M, MPerBlock).
194 __host__ __device__ static auto CalculateMPadded(index_t M)
195 {
196 return math::integer_least_multiple(M, MPerBlock);
197 }
198
// Pads N up to a multiple of the NPerBlock tile size by delegating to
// math::integer_least_multiple(N, NPerBlock).
199 __host__ __device__ static auto CalculateNPadded(index_t N)
200 {
201 return math::integer_least_multiple(N, NPerBlock);
202 }
203
// Pads K up to a multiple of (KPerBlock * K_Batch), so the padded K can be divided evenly
// into K_Batch slices that are each a whole number of KPerBlock tiles.
204 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
205 {
206 return math::integer_least_multiple(K, KPerBlock * K_Batch);
207 }
208
209 template <typename ALayout, GemmSpecialization GemmSpec>
210 __host__ __device__ static auto
212 {
213 const auto a_grid_desc_m_k = [&]() {
215 {
216 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
217 }
219 {
220 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
221 }
222 }();
223
224 const auto MPad = CalculateMPadded(M);
225 const auto KPad = CalculateKPadded(K, KBatch);
226
227 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
228 a_grid_desc_m_k,
232
233 const auto AK0 = KPad / (KBatch * AK1);
234
239 {
240 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
242 a_grid_desc_m_kpad,
244 make_right_pad_transform(M, MPad - M)),
247 }
248 else
249 {
251 a_grid_desc_m_kpad,
256 }
257 }
258
259 template <typename BLayout, GemmSpecialization GemmSpec>
260 __host__ __device__ static auto
262 {
263 const auto b_grid_desc_k_n = [&]() {
265 {
266 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
267 }
269 {
270 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
271 }
272 }();
273
274 const auto NPad = CalculateNPadded(N);
275 const auto KPad = CalculateKPadded(K, KBatch);
276
277 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
278 b_grid_desc_k_n,
282
283 const auto BK0 = KPad / (KBatch * BK1);
284
289 {
290 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
292 b_grid_desc_kpad_n,
294 make_right_pad_transform(N, NPad - N)),
297 }
298 else
299 {
301 b_grid_desc_kpad_n,
306 }
307 }
308
309 // E desc for destination in blockwise copy
// Takes a 2-D (M, N) descriptor and returns a descriptor produced by
// transform_tensor_descriptor; going by the function and variable names the result is the
// 4-D (MBlock, MPerBlock, NBlock, NPerBlock) view.
// NOTE(review): the transform arguments (listing lines 322-325) are missing from this
// listing, so the exact transforms cannot be confirmed here.
// NOTE(review): MBlock / NBlock use truncating integer division; this presumably relies on
// the incoming descriptor already being padded to tile multiples -- confirm with callers.
310 template <typename EGridDesc_M_N>
311 __host__ __device__ static constexpr auto
312 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
313 {
 // Dimension 0 is read as M and dimension 1 as N.
314 const auto M = e_grid_desc_m_n.GetLength(I0);
315 const auto N = e_grid_desc_m_n.GetLength(I1);
316
 // Number of tiles along each dimension.
317 const auto MBlock = M / MPerBlock;
318 const auto NBlock = N / NPerBlock;
319
320 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
321 e_grid_desc_m_n,
326
327 return e_grid_desc_mblock_mperblock_nblock_nperblock;
328 }
329
330 // Ds desc for source in blockwise copy
331 template <typename DsGridDesc_M_N>
332 __host__ __device__ static constexpr auto
333 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
334 {
335 return generate_tuple(
336 [&](auto i) {
338 },
340 }
341
342 // return block_id to E matrix tile idx (m0, n0) mapping
343 template <typename EGridDesc_M_N>
344 __host__ __device__ static constexpr auto
345 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
346 {
348 e_grid_desc_m_n);
349 }
350
352
// Validity check for a problem size (the line carrying the function name and first
// parameter is missing from this listing).
// Returns false when either
//   * GridwiseGemmPipe::IsSupported rejects the K-loop count computed as K / KPerBlock, or
//   * the element space of the A, B or E descriptor, multiplied by the element size,
//     exceeds 2^31 bytes;
// otherwise returns true.
// StrideDs is accepted but intentionally unused (it is assigned to `ignore`).
// The tile-divisibility test on M, N and K is compiled out by the `#if 0` region.
// NOTE(review): the initializers of the A and B descriptors (listing lines 369 and 371) are
// missing from this listing.
// NOTE(review): num_k_loop is K / KPerBlock with no KBatch factor, while Run() derives its
// loop count from the per-batch A descriptor lengths -- confirm that callers pass a K that
// keeps the two consistent.
353 template <typename ALayout,
354 typename BLayout,
355 typename DsLayout,
356 typename ELayout,
357 GemmSpecialization GemmSpec>
358 __host__ __device__ static constexpr bool
360 const index_t N,
361 const index_t K,
362 const index_t StrideA,
363 const index_t StrideB,
364 const std::array<index_t, NumDTensor> StrideDs,
365 const index_t StrideE,
366 const index_t KBatch)
367 {
368 const auto a_grid_desc_kbatch_ak0_m_ak1 =
370 const auto b_grid_desc_kbatch_bk0_n_bk1 =
372
373 ignore = StrideDs;
374
375 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
376
377#if 0
378 // check tile size
379 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
380 {
381 return false;
382 }
383#endif
384
385 // check gridwise gemm pipeline
386 const auto num_k_loop = K / KPerBlock;
387
388 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
389 {
390 return false;
391 }
392
393 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
394 // check tensor size: cannot be larger than 2GB each
395 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
396
397 if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
398 b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
399 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
400 {
401 return false;
402 }
403
404 return true;
405 }
406
// Reports whether the K loop has a "main loop" portion for the given K, as decided by
// GridwiseGemmPipe::CalculateHasMainLoop on the iteration count K / KPerBlock.
// NOTE(review): the division truncates and does not involve KBatch; Run() computes its own
// iteration count from the per-batch descriptor lengths -- confirm callers pass a matching K.
407 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
408 {
409 const index_t num_loop = K / KPerBlock;
410
411 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
412 }
413
414 using DsGridPointer = decltype(MakeDsGridPointer());
415
// Builds the (M, N) grid descriptor for an output-shaped tensor and pads it.
// A naive 2-D descriptor of lengths (MRaw, NRaw) is created with strides (StrideE, 1) in
// one branch and (1, StrideE) in the other, and the result is passed through
// matrix_padder.PadCDescriptor_M_N.
// NOTE(review): the function-name/parameter line (listing line 418), the matrix_padder
// type (line 421) and both branch conditions (lines 424 and 429) are missing from this
// listing. The name is taken from the call sites in this file; the branches presumably
// select on ELayout being row- vs column-major -- confirm against the real header.
416 template <typename ELayout, GemmSpecialization GemmSpec>
417 __host__ __device__ static auto
419 {
420 constexpr auto matrix_padder =
422 MPerBlock, NPerBlock, KPerBlock};
423 const auto e_grid_desc_mraw_nraw = [&]() {
425 {
 // Unit stride along N, StrideE along M.
426 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
427 make_tuple(StrideE, I1));
428 }
430 {
 // Unit stride along M, StrideE along N.
431 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
432 make_tuple(I1, StrideE));
433 }
434 }();
435
436 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
437 }
438
// Builds a tuple of (M, N) grid descriptors, one per D tensor.
// For index i the layout is the i-th element of DsLayout and the descriptor is produced by
// MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]), so every D
// tensor goes through the same construction as E.
// NOTE(review): the tuple-length argument to generate_tuple (listing line 451) is missing
// from this listing; presumably Number<NumDTensor>{} -- confirm.
439 template <typename DsLayout, GemmSpecialization GemmSpec>
440 __host__ __device__ static auto
441 MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
442 const std::array<index_t, NumDTensor>& NRaws,
443 const std::array<index_t, NumDTensor>& DsStride)
444 {
445 return generate_tuple(
446 [&](auto i) {
447 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
448
449 return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
450 },
452 }
453
// Accessor exposing the MPerBlock template parameter (tile size along M) to callers.
454 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
455
456 template <bool HasMainKBlockLoop,
457 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
458 index_t NumDTensor_,
459 typename DsDataType_,
460 typename AGridDesc_KBatch_AK0_M_AK1,
461 typename BGridDesc_KBatch_BK0_N_BK1,
462 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
463 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
464 typename CDEElementwiseOperation_,
465 typename Block2ETileMap>
466 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
467 const BDataType* __restrict__ p_b_grid,
468 DsGridPointer p_ds_grid,
469 EDataType* __restrict__ p_e_grid,
470 void* __restrict__ p_shared,
471 uint32_t* barrier_count_finished,
472 const index_t KBatch,
473 const AElementwiseOperation& a_element_op,
474 const BElementwiseOperation& b_element_op,
475 const CDEElementwiseOperation_& cde_element_op,
476 const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
477 const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
478 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
479 ds_grid_desc_mblock_mperblock_nblock_nperblock,
480 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
481 e_grid_desc_mblock_mperblock_nblock_nperblock,
482 const Block2ETileMap& block_2_etile_map)
483 {
484 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
485 p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
486
487 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
488 p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
489
490 const auto ds_grid_buf = generate_tuple(
491 [&](auto i) {
493 p_ds_grid[i],
494 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
495 },
497
499 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
500
501 // divide block work by [M, N]
502 const auto block_work_idx =
503 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
504
505 // HACK: this forces kbatch_id and m/n_block_data_idx_on_grid into SGPR
506 const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
507
508 const index_t m_block_data_idx_on_grid =
509 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
510
511 const index_t n_block_data_idx_on_grid =
512 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
513
514 // lds max alignment
515 constexpr auto max_lds_align = math::lcm(AK1, BK1);
516
517 // A matrix in LDS memory, dst of blockwise copy
518 constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
520
521 // B matrix in LDS memory, dst of blockwise copy
522 constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
524
525 // A matrix blockwise copy
526 auto a_blockwise_copy =
528 AElementwiseOperation,
532 ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
533 ABlockTransferThreadClusterArrangeOrder,
534 ADataType,
535 ComputeType,
536 decltype(a_grid_desc_kbatch_ak0_m_ak1),
537 decltype(a_block_desc_kbatch_ak0_m_ak1),
538 ABlockTransferSrcAccessOrder,
540 ABlockTransferSrcVectorDim,
541 3,
542 ABlockTransferSrcScalarPerVector,
543 ABlockTransferDstScalarPerVector_AK1,
544 1,
545 1,
546 AThreadTransferSrcResetCoordinateAfterRun,
547 true,
548 NumGemmKPrefetchStage>(
549 a_grid_desc_kbatch_ak0_m_ak1,
550 make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
551 a_element_op,
552 a_block_desc_kbatch_ak0_m_ak1,
553 make_multi_index(0, 0, 0, 0),
555
556 // B matrix blockwise copy
557 auto b_blockwise_copy =
559 BElementwiseOperation,
563 BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
564 BBlockTransferThreadClusterArrangeOrder,
565 BDataType,
566 ComputeType,
567 decltype(b_grid_desc_kbatch_bk0_n_bk1),
568 decltype(b_block_desc_kbatch_bk0_n_bk1),
569 BBlockTransferSrcAccessOrder,
571 BBlockTransferSrcVectorDim,
572 3,
573 BBlockTransferSrcScalarPerVector,
574 BBlockTransferDstScalarPerVector_BK1,
575 1,
576 1,
577 BThreadTransferSrcResetCoordinateAfterRun,
578 true,
579 NumGemmKPrefetchStage>(
580 b_grid_desc_kbatch_bk0_n_bk1,
581 make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
582 b_element_op,
583 b_block_desc_kbatch_bk0_n_bk1,
584 make_multi_index(0, 0, 0, 0),
586
587 // A matrix in LDS memory, dst of blockwise copy
588 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
589
590 // B matrix in LDS memory, dst of blockwise copy
591 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
592
593 // GEMM definition
594 // c_mtx += transpose(a_mtx) * b_mtx
595 // a_mtx[K0PerBlock, MPerBlock] is in LDS
596 // b_mtx[K0PerBlock, NPerBlock] is in LDS
597 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
598 // register
599 // sanity check
600 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
601 constexpr bool is_single_rate_mfma =
603 lcm_AK1_BK1 <= 4) ||
604 (is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
606 lcm_AK1_BK1 < 32))
607 ? true
608 : false;
609 constexpr auto is_scale_mfma = false;
610 constexpr index_t KPack = math::max(lcm_AK1_BK1,
611 MfmaSelector<ComputeType,
612 MPerXdl,
613 NPerXdl,
614 ComputeType,
615 is_single_rate_mfma,
616 is_scale_mfma>::selected_mfma.k_per_blk);
617
619 BlockSize,
620 ComputeType,
621 AccDataType,
622 decltype(a_block_desc_ak0_m_ak1),
623 decltype(b_block_desc_bk0_n_bk1),
624 MPerXdl,
625 NPerXdl,
626 MXdlPerWave,
627 NXdlPerWave,
628 KPack,
629 LoopSched>();
630
631#if 1
632 if(block_work_idx[I0] == 0)
633 {
634 const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
635 const index_t numNThreads = NPerBlock / nThreadSize;
636 const index_t numMThreads = BlockSize / numNThreads;
637 const index_t mThreadSize = MPerBlock / numMThreads;
638
639 const index_t m_tid = get_thread_local_1d_id() / numNThreads;
640 const index_t n_tid = get_thread_local_1d_id() % numNThreads;
641
642 auto c_thread_desc_mblock_mperblock_nblock_nperblock =
645
647 EDataType,
648 c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
649 true>
650 e_thread_zero_buf;
651
652 auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
653 EDataType,
654 EDataType,
655 decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
656 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
660 3,
661 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
663 1,
664 true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
665 make_multi_index(block_work_idx[I1],
666 m_tid * mThreadSize,
667 block_work_idx[I2],
668 n_tid * nThreadSize),
670
671 c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
672 make_tuple(I0, I0, I0, I0),
673 e_thread_zero_buf,
674 e_grid_desc_mblock_mperblock_nblock_nperblock,
675 e_grid_buf);
676
677 __syncthreads();
678
679 if(threadIdx.x == 0)
680 {
681 atomicAdd(barrier_count_finished, 1);
682 }
683 }
684#endif
685
686 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
687
688 // LDS allocation for A and B: be careful of alignment
689 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
690 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
691
693 static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
694
696 static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
697 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
698
699 constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
700 constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
701
702 // gridwise GEMM pipeline
703 const auto gridwise_gemm_pipeline =
705
706 const index_t num_k_block_main_loop =
707 __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
708 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
709 KPerBlock);
710
711 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
712 a_block_desc_kbatch_ak0_m_ak1,
713 a_blockwise_copy,
714 a_grid_buf,
715 a_block_buf,
716 a_block_slice_copy_step,
717 b_grid_desc_kbatch_bk0_n_bk1,
718 b_block_desc_kbatch_bk0_n_bk1,
719 b_blockwise_copy,
720 b_grid_buf,
721 b_block_buf,
722 b_block_slice_copy_step,
723 blockwise_gemm,
724 c_thread_buf,
725 num_k_block_main_loop);
726
727 // shuffle C and write out
728 {
729 if(threadIdx.x == 0)
730 {
731 while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
732 }
733
734 __syncthreads();
735
736 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
737 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
738 "wrong!");
739
740 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
741 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
742
743 // TODO: hacky, fix it!
744 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
745 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
746
747 // TODO: hacky, fix it!
748 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
749 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
750 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
751
752 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
753 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
754 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
755 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
756 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
757 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
758 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
759 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
760
761 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
763
764 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
765 static_cast<CShuffleDataType*>(p_shared),
766 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
767
768 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
769 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
773 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
774 M1, // M1 = MWave
775 M2, // M2 * M3 * M4 = MPerXdl
776 M3,
777 M4)),
780 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
781 N1, // N1 = NWave
782 N2))), // N2 = NPerXdl
786
787 // calculate origin of thread output tensor on global memory
788 // blockwise GEMM c matrix starting index
789 const auto c_thread_mtx_on_block =
790 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
791
792 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
793 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
794
795 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
797 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
800
801 const auto m_thread_data_on_block_idx =
802 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
803 make_multi_index(m_thread_data_on_block));
804
805 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
810
811 const auto n_thread_data_on_block_idx =
812 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
813 make_multi_index(n_thread_data_on_block));
814
815 // shuffle: threadwise copy C from VGPR to LDS
816 auto c_thread_copy_vgpr_to_lds =
818 CShuffleDataType,
819 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
820 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
822 Sequence<CShuffleMXdlPerWavePerShuffle,
823 CShuffleNXdlPerWavePerShuffle,
824 I1,
825 I1,
826 M2,
827 I1,
828 M4,
829 I1>,
831 7,
832 1,
834 1,
835 true>{
836 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
838 0,
839 m_thread_data_on_block_idx[I1],
840 n_thread_data_on_block_idx[I1],
841 m_thread_data_on_block_idx[I2],
842 m_thread_data_on_block_idx[I3],
843 m_thread_data_on_block_idx[I4],
844 n_thread_data_on_block_idx[I2]),
846
847 // tuple of reference to C/Ds tensor descriptors
848 const auto c_ds_desc_refs = concat_tuple_of_reference(
849 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
850 generate_tie([&](auto i) -> const auto& // return type should be reference
851 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
853
854 // tuple of reference to C/Ds tensor descriptors
855 const auto c_ds_buf_refs = concat_tuple_of_reference(
856 tie(c_shuffle_block_buf),
857 generate_tie([&](auto i) -> const auto& // return type should be reference
858 { return ds_grid_buf[i]; },
860
861 // tuple of starting index of C/Ds blockwise copy
862 const auto idx_c_ds_block_begin = container_concat(
863 make_tuple(make_multi_index(0, 0, 0, 0)),
865 [&](auto) {
866 return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
867 },
869
870 // space filling curve for threadwise C in VGPR before shuffle
871 constexpr auto sfc_c_vgpr =
874 Sequence<CShuffleMXdlPerWavePerShuffle,
875 CShuffleNXdlPerWavePerShuffle,
876 1,
877 1,
878 M2,
879 1,
880 M4,
881 1>>{};
882
883 // space filling curve for shuffled blockwise C/D/E
884 constexpr auto sfc_cde_block =
887 Sequence<1,
888 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
889 1,
890 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
891
892 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
893
894 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
895
896 // blockwise copy C/D/E between LDS and global
897 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
899 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
901 decltype(c_ds_desc_refs),
902 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
903 CDEElementwiseOperation_,
904 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
905 // Sequence support
906 // arbitrary type
907 Sequence<1,
908 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
909 1,
910 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
911 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
912 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
913 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
914 3, // index_t VectorDim,
915 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
918 uniform_sequence_gen_t<NumDTensor_,
919 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
920 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
921 {c_ds_desc_refs,
922 idx_c_ds_block_begin,
923 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
924 make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
925 cde_element_op};
926
927 static_for<0, num_access, 1>{}([&](auto access_id) {
928 // make sure it's safe to write to LDS
930
931 // each thread write its data from VGPR to LDS
932 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
933 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
934 c_thread_buf,
935 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
936 c_shuffle_block_buf);
937
938 // make sure it's safe to read from LDS
940
941 // each block copy its data from LDS to global
942 cde_block_copy_lds_and_global.Run(
943 c_ds_desc_refs,
944 c_ds_buf_refs,
945 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
946 tie(e_grid_buf));
947
948 if constexpr(access_id < num_access - 1)
949 {
950 constexpr auto cde_lds_and_global_step =
951 sfc_cde_block.GetForwardStep(access_id);
952
953 // move on Ds
954 static_for<0, NumDTensor_, 1>{}([&](auto i) {
955 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
956 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
957 });
958
959 // move on E
960 cde_block_copy_lds_and_global.MoveDstSliceWindow(
961 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
962 I0,
963 cde_lds_and_global_step);
964 }
965 });
966
967 if(threadIdx.x == 0)
968 {
969 index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
970
971 if(k_id_finished_t == KBatch)
972 {
973 *barrier_count_finished = 0;
974 }
975 }
976 }
977 }
978
979 template <bool HasMainKBlockLoop,
980 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
981 GemmSpecialization GemmSpec,
982 typename ALayout,
983 typename BLayout,
984 typename DsLayout,
985 typename ELayout,
986 typename Block2ETileMap>
987 __device__ static void Run(const void* __restrict__ p_a_grid_,
988 const void* __restrict__ p_b_grid_,
989 DsGridPointer p_ds_grid,
990 void* __restrict__ p_e_grid_,
991 void* __restrict__ p_shared,
992 uint32_t* barrier_count_finished,
993 const AElementwiseOperation& a_element_op,
994 const BElementwiseOperation& b_element_op,
995 const CDEElementwiseOperation& cde_element_op,
996 const index_t M,
997 const index_t N,
998 const index_t K,
999 const index_t StrideA,
1000 const index_t StrideB,
1001 const std::array<index_t, NumDTensor> StrideDs,
1002 const index_t StrideE,
1003 const index_t KBatch,
1004 const Block2ETileMap& block_2_etile_map)
1005 {
1006 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
1007 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
1008 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1009
1010 using DsGridDesc_M_N =
1012
1013 DsGridDesc_M_N ds_grid_desc_m_n;
1014
1015 static_for<0, NumDTensor, 1>{}([&](auto j) {
1016 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1017
1018 ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1019 });
1020
1021 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1022
1023 // tensor descriptors for block/thread-wise copy
1024 const auto a_grid_desc_kbatch_ak0_m_ak1 =
1026
1027 const auto b_grid_desc_kbatch_bk0_n_bk1 =
1029
1030 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
1032 DsGridDesc_M_N{}))>;
1033
1034 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
1035
1036 static_for<0, NumDTensor, 1>{}([&](auto j) {
1037 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
1039 });
1040
1041 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1043
1044 const auto block_work_idx =
1045 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1046
1047 const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1048
1049 if(kbatch_id == KBatch - 1)
1050 {
1052 p_a_grid,
1053 p_b_grid,
1054 p_ds_grid,
1055 p_e_grid,
1056 p_shared,
1057 barrier_count_finished,
1058 KBatch,
1059 a_element_op,
1060 b_element_op,
1061 cde_element_op,
1062 a_grid_desc_kbatch_ak0_m_ak1,
1063 b_grid_desc_kbatch_bk0_n_bk1,
1064 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1065 e_grid_desc_mblock_mperblock_nblock_nperblock,
1066 block_2_etile_map);
1067 }
1068 else
1069 {
1071 p_a_grid,
1072 p_b_grid,
1073 p_ds_grid,
1074 p_e_grid,
1075 p_shared,
1076 barrier_count_finished,
1077 KBatch,
1078 a_element_op,
1079 b_element_op,
1081 a_grid_desc_kbatch_ak0_m_ak1,
1082 b_grid_desc_kbatch_bk0_n_bk1,
1083 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1084 e_grid_desc_mblock_mperblock_nblock_nperblock,
1085 block_2_etile_map);
1086 }
1087 }
1088};
1089
1090} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:79
static constexpr auto I3
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:87
static constexpr auto BK1
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:95
static constexpr auto I5
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:89
static __device__ void Run(const void *__restrict__ p_a_grid_, const void *__restrict__ p_b_grid_, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, uint32_t *barrier_count_finished, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideE, const index_t KBatch, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp:987
static constexpr auto BK0PerBlock
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:97
static constexpr auto I7
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:91
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, uint32_t *barrier_count_finished, const index_t KBatch, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation_ &cde_element_op, const AGridDesc_KBatch_AK0_M_AK1 &a_grid_desc_kbatch_ak0_m_ak1, const BGridDesc_KBatch_BK0_N_BK1 &b_grid_desc_kbatch_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp:466
static constexpr auto I6
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:90
static constexpr auto I2
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:86
static constexpr auto I4
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:88
static constexpr auto I1
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:85
__host__ static __device__ auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:421
static constexpr index_t NumDTensor
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:80
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, uint32_t *barrier_count_finished, const index_t KBatch, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation_ &cde_element_op, const AGridDesc_KBatch_AK0_M_AK1 &a_grid_desc_kbatch_ak0_m_ak1, const BGridDesc_KBatch_BK0_N_BK1 &b_grid_desc_kbatch_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:470
static constexpr auto AK0PerBlock
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:96
static constexpr auto AK1
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:94
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes.
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340