gridwise_gemm_multiple_d_xdl_cshuffle.hpp Source File

gridwise_gemm_multiple_d_xdl_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_multiple_d_xdl_cshuffle.hpp Source File
gridwise_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
20
21namespace ck {
22
23// GEMM:
24// input : A[M, K]
25// input : B[N, K]
26// input : D0[M, N], D1[M, N], ...
27// output : E[M, N]
28// C = a_op(A) * b_op(B)
29// E = cde_op(C, D0, D1, ...)
30// Assume:
31// D0, D1, ... and E have the same layout
32template <typename ADataType,
33 typename BDataType,
34 typename AComputeDataType_,
35 typename AccDataType,
36 typename CShuffleDataType,
37 typename DsDataType,
38 typename EDataType,
39 typename AElementwiseOperation,
40 typename BElementwiseOperation,
41 typename CDEElementwiseOperation,
42 index_t NumGemmKPrefetchStage,
43 index_t BlockSize,
44 index_t MPerBlock,
45 index_t NPerBlock,
46 index_t KPerBlock,
47 index_t AK1Value,
48 index_t BK1Value,
49 index_t MPerXdl,
50 index_t NPerXdl,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
59 bool AThreadTransferSrcResetCoordinateAfterRun,
60 index_t ABlockLdsExtraM,
61 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62 typename BBlockTransferThreadClusterArrangeOrder,
63 typename BBlockTransferSrcAccessOrder,
64 index_t BBlockTransferSrcVectorDim,
65 index_t BBlockTransferSrcScalarPerVector,
66 index_t BBlockTransferDstScalarPerVector_BK1,
67 bool BThreadTransferSrcResetCoordinateAfterRun,
68 index_t BBlockLdsExtraN,
69 index_t CShuffleMXdlPerWavePerShuffle,
70 index_t CShuffleNXdlPerWavePerShuffle,
71 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
73 LoopScheduler LoopSched,
75 typename BComputeDataType_ = AComputeDataType_,
76 bool DoElementwiseBeforeCShuffle = false>
78{
79 static constexpr index_t NumDTensor = DsDataType::Size();
80 static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0);
81
83
84 static constexpr auto I0 = Number<0>{};
85 static constexpr auto I1 = Number<1>{};
86 static constexpr auto I2 = Number<2>{};
87 static constexpr auto I3 = Number<3>{};
88 static constexpr auto I4 = Number<4>{};
89 static constexpr auto I5 = Number<5>{};
90 static constexpr auto I6 = Number<6>{};
91 static constexpr auto I7 = Number<7>{};
92
93 // K1 should be Number<...>
94 static constexpr auto AK1 = Number<AK1Value>{};
95 static constexpr auto BK1 = Number<BK1Value>{};
96 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
97 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
98
100
103
104#if CK_GFX90A_DENORM_WORKAROUND
105 using AComputeDataType =
107 using BComputeDataType =
109#else
114#endif
115
116 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
117 {
118 // A matrix in LDS memory, dst of blockwise copy
122 }
123
124 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
125 {
126 // B matrix in LDS memory, dst of blockwise copy
130 }
131
132 __host__ __device__ static constexpr auto
134 {
135 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
136 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
137
138 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
142 I1,
144
145 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
146 }
147
148 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
149 static constexpr auto MakeDsGridPointer()
150 {
151 return generate_tuple(
152 [&](auto i) {
153 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
154
155 return static_cast<const DDataType*>(nullptr);
156 },
158 }
159
160 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
161 {
162 // LDS allocation for A and B: be careful of alignment
163 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
164 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
165
166 // lds max alignment
167 constexpr auto max_lds_align = math::lcm(AK1, BK1);
168
169 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
170 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
171
172 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
173 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
174
175 // LDS allocation for C shuffle in LDS
176 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
178
179 constexpr auto c_block_size =
180 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
181
182 return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
183 b_block_space_size_aligned * sizeof(BComputeDataType),
184 c_block_size * sizeof(CShuffleDataType));
185 }
186
187 // A desc for source in blockwise copy
188 template <typename AGridDesc_M_K>
189 __host__ __device__ static constexpr auto
190 MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
191 {
192 const auto M = a_grid_desc_m_k.GetLength(I0);
193 const auto K = a_grid_desc_m_k.GetLength(I1);
194
195 const auto AK0 = K / AK1;
196
197 return transform_tensor_descriptor(a_grid_desc_m_k,
202 }
203
204 // B desc for source in blockwise copy
205 template <typename BGridDesc_N_K>
206 __host__ __device__ static constexpr auto
207 MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
208 {
209 const auto N = b_grid_desc_n_k.GetLength(I0);
210 const auto K = b_grid_desc_n_k.GetLength(I1);
211
212 const auto BK0 = K / BK1;
213
214 return transform_tensor_descriptor(b_grid_desc_n_k,
219 }
220
221 // E desc for destination in blockwise copy
222 template <typename EGridDesc_M_N>
223 __host__ __device__ static constexpr auto
224 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
225 {
226 const auto M = e_grid_desc_m_n.GetLength(I0);
227 const auto N = e_grid_desc_m_n.GetLength(I1);
228
229 const auto MBlock = M / MPerBlock;
230 const auto NBlock = N / NPerBlock;
231
232 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
233 e_grid_desc_m_n,
238
239 return e_grid_desc_mblock_mperblock_nblock_nperblock;
240 }
241
242 // Ds desc for source in blockwise copy
243 template <typename DsGridDesc_M_N>
244 __host__ __device__ static constexpr auto
245 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
246 {
247 return generate_tuple(
248 [&](auto i) {
250 },
252 }
253
254 // return block_id to E matrix tile idx (m0, n0) mapping
255 template <typename EGridDesc_M_N>
256 __host__ __device__ static constexpr auto
257 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
258 {
260 e_grid_desc_m_n);
261 }
262
263 template <typename ALayout, typename BLayout, typename ELayout>
264 __host__ __device__ static bool
266 {
267 // Check if the vector dim is K1 or M|N
268 const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
269 const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
270 const auto E_vector_dim_size = NRaw;
271
272 // check vector load for A tensor
274 {
275 if(!(A_vector_dim_size == KRaw &&
276 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
277 return false;
278 }
280 {
281 if(!(A_vector_dim_size == MRaw &&
282 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
283 return false;
284 }
285 else
286 {
287 return false;
288 }
289
291 {
292 if(!(B_vector_dim_size == NRaw &&
293 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
294 return false;
295 }
297 {
298 if(!(B_vector_dim_size == KRaw &&
299 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
300 return false;
301 }
302 else
303 {
304 return false;
305 }
306
308 {
309 if(!(E_vector_dim_size == NRaw &&
310 E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
311 return false;
312 }
314 {
315 if(!(E_vector_dim_size == NRaw &&
316 CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
317 return false;
318 }
319 else
320 {
321 return false;
322 }
323
324 return true;
325 }
326
328
329 template <typename AGridDesc_M_K,
330 typename BGridDesc_N_K,
331 typename DsGridDesc_M_N,
332 typename EGridDesc_M_N,
333 typename Block2ETileMap>
334 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
335 const BGridDesc_N_K& b_grid_desc_n_k,
336 const DsGridDesc_M_N& ds_grid_desc_m_n,
337 const EGridDesc_M_N& e_grid_desc_m_n,
338 [[maybe_unused]] const Block2ETileMap&,
339 index_t k_batch = 1)
340 {
341 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
342 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
343 "Invalid tuning param!");
344
345 static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
346 "KPerBlock must be divisible by AK1Value and BK1Value!");
347
348 const auto M = a_grid_desc_m_k.GetLength(I0);
349 const auto N = b_grid_desc_n_k.GetLength(I0);
350 const auto AK = a_grid_desc_m_k.GetLength(I1);
351 const auto BK = b_grid_desc_n_k.GetLength(I1);
352
353 // check consistency of desc
354 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
355 {
356 return false;
357 }
358 bool valid = true;
359
360 static_for<0, NumDTensor, 1>{}([&](auto i) {
361 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
362 N == ds_grid_desc_m_n[i].GetLength(I1));
363 });
364
365 if(!valid)
366 {
367 return false;
368 }
369
370 // check tile size
371 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
372 {
373 return false;
374 }
375
376 // check gridwise gemm pipeline
377 const auto num_k_loop = AK / (KPerBlock * k_batch);
378 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
379 {
380 return false;
381 }
382
383 // check block-to-E-tile
384 // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
385 //{
386 // return false;
387 //}
388
389 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
390 // check tensor size: cannot be larger than 2GB each
391 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
392
393 if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
394 b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
395 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
396 {
397 return false;
398 }
399
400 return true;
401 }
402
403 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K,
404 index_t k_batch = 1)
405 {
406 const index_t num_loop = K / (KPerBlock * k_batch);
407
408 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
409 }
410
411 using DsGridPointer = decltype(MakeDsGridPointer());
412
413 template <typename ALayout, GemmSpecialization GemmSpec>
414 __host__ __device__ static auto
416 {
417 constexpr auto matrix_padder =
419 MPerBlock, NPerBlock, KPerBlock};
420
421 const auto a_grid_desc_mraw_kraw = [&]() {
423 {
424 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
425 make_tuple(StrideA, I1));
426 }
428 {
429 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
430 make_tuple(I1, StrideA));
431 }
432 }();
433
434 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
435 }
436
437 template <typename BLayout, GemmSpecialization GemmSpec>
438 __host__ __device__ static auto
440 {
441 constexpr auto matrix_padder =
443 MPerBlock, NPerBlock, KPerBlock};
444
445 const auto b_grid_desc_nraw_kraw = [&]() {
447 {
448 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
449 make_tuple(I1, StrideB));
450 }
452 {
453 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
454 make_tuple(StrideB, I1));
455 }
456 }();
457
458 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
459 }
460
461 template <typename ELayout, GemmSpecialization GemmSpec>
462 __host__ __device__ static auto
464 {
465 constexpr auto matrix_padder =
467 MPerBlock, NPerBlock, KPerBlock};
468 const auto e_grid_desc_mraw_nraw = [&]() {
470 {
471 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
472 make_tuple(StrideE, I1));
473 }
475 {
476 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
477 make_tuple(I1, StrideE));
478 }
479 }();
480
481 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
482 }
483
484#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
485 template <typename DsLayout, GemmSpecialization GemmSpec>
486 __host__ __device__ static auto
489 const ck::Array<index_t, NumDTensor>& DsStride)
490#else
491 template <typename DsLayout, GemmSpecialization GemmSpec>
492 __host__ __device__ static auto
493 MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
494 const std::array<index_t, NumDTensor>& NRaws,
495 const std::array<index_t, NumDTensor>& DsStride)
496#endif
497
498 {
499 return generate_tuple(
500 [&](auto i) {
501 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
502
503 return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
504 },
506 }
507
508 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
509
510 template <bool HasMainKBlockLoop,
511 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
512 typename AGridDesc_AK0_M_AK1,
513 typename BGridDesc_BK0_N_BK1,
514 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
515 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
516 typename Block2ETileMap>
517 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
518 const BDataType* __restrict__ p_b_grid,
519 DsGridPointer p_ds_grid,
520 EDataType* __restrict__ p_e_grid,
521 void* __restrict__ p_shared,
522 const AElementwiseOperation& a_element_op,
523 const BElementwiseOperation& b_element_op,
524 const CDEElementwiseOperation& cde_element_op,
525 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
526 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
527 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
528 ds_grid_desc_mblock_mperblock_nblock_nperblock,
529 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
530 e_grid_desc_mblock_mperblock_nblock_nperblock,
531 const Block2ETileMap& block_2_etile_map,
532 const index_t k_batch = 1,
533 const index_t k_idx = 0)
534 {
535 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
536 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
537
538 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
539 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
540
541 const auto ds_grid_buf = generate_tuple(
542 [&](auto i) {
544 p_ds_grid[i],
545 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
546 },
548
550 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
551
552 // divide block work by [M, N]
553 const auto block_work_idx =
554 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
555
556 if(!block_2_etile_map.ValidCTileIndex(
557 block_work_idx,
558 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
559 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
560 {
561 return;
562 }
563
564 const index_t num_ak0_per_block =
565 __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch);
566 const index_t num_bk0_per_block =
567 __builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch);
568 // HACK: this force m/n_block_data_idx_on_grid into SGPR
569 const index_t m_block_data_idx_on_grid =
570 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
571
572 const index_t n_block_data_idx_on_grid =
573 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
574
575 // lds max alignment
576 constexpr auto max_lds_align = math::lcm(AK1, BK1);
577
578 // A matrix in LDS memory, dst of blockwise copy
579 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
580
581 // B matrix in LDS memory, dst of blockwise copy
582 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
583
584 // A matrix blockwise copy
585 auto a_blockwise_copy =
587 AElementwiseOperation,
591 ABlockTransferThreadClusterLengths_AK0_M_AK1,
592 ABlockTransferThreadClusterArrangeOrder,
593 ADataType,
595 decltype(a_grid_desc_ak0_m_ak1),
596 decltype(a_block_desc_ak0_m_ak1),
597 ABlockTransferSrcAccessOrder,
599 ABlockTransferSrcVectorDim,
600 2,
601 ABlockTransferSrcScalarPerVector,
602 ABlockTransferDstScalarPerVector_AK1,
603 1,
604 1,
605 AThreadTransferSrcResetCoordinateAfterRun,
606 true,
607 NumGemmKPrefetchStage>(
608 a_grid_desc_ak0_m_ak1,
609 make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0),
610 a_element_op,
611 a_block_desc_ak0_m_ak1,
612 make_multi_index(0, 0, 0),
614
615 // B matrix blockwise copy
616 auto b_blockwise_copy =
618 BElementwiseOperation,
622 BBlockTransferThreadClusterLengths_BK0_N_BK1,
623 BBlockTransferThreadClusterArrangeOrder,
624 BDataType,
626 decltype(b_grid_desc_bk0_n_bk1),
627 decltype(b_block_desc_bk0_n_bk1),
628 BBlockTransferSrcAccessOrder,
630 BBlockTransferSrcVectorDim,
631 2,
632 BBlockTransferSrcScalarPerVector,
633 BBlockTransferDstScalarPerVector_BK1,
634 1,
635 1,
636 BThreadTransferSrcResetCoordinateAfterRun,
637 true,
638 NumGemmKPrefetchStage>(
639 b_grid_desc_bk0_n_bk1,
640 make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),
641 b_element_op,
642 b_block_desc_bk0_n_bk1,
643 make_multi_index(0, 0, 0),
645
646 // GEMM definition
647 // c_mtx += transpose(a_mtx) * b_mtx
648 // a_mtx[K0PerBlock, MPerBlock] is in LDS
649 // b_mtx[K0PerBlock, NPerBlock] is in LDS
650 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
651 // register
652 // sanity check
653 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
654 constexpr bool is_single_rate_mfma =
657 lcm_AK1_BK1 <= 4) ||
658 (is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
660 lcm_AK1_BK1 < 32))
661 ? true
662 : false;
663 constexpr auto is_scale_mfma = false;
664 constexpr index_t KPack = math::max(lcm_AK1_BK1,
665 MfmaSelector<AComputeDataType_,
666 MPerXdl,
667 NPerXdl,
668 BComputeDataType_,
669 is_single_rate_mfma,
670 is_scale_mfma>::selected_mfma.k_per_blk);
672 BlockSize,
675 AccDataType,
676 decltype(a_block_desc_ak0_m_ak1),
677 decltype(b_block_desc_bk0_n_bk1),
678 MPerXdl,
679 NPerXdl,
680 MXdlPerWave,
681 NXdlPerWave,
682 KPack,
683 LoopSched,
684 AComputeDataType_,
685 BComputeDataType_>();
686
687 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
688
689 // LDS allocation for A and B: be careful of alignment
690 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
691 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
692
694 static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
695
697 static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
698 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
699
700 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
701 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
702
703 // gridwise GEMM pipeline
704 const auto gridwise_gemm_pipeline =
706
707 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
708 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
709 (KPerBlock * k_batch));
710
711 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
712 a_block_desc_ak0_m_ak1,
713 a_blockwise_copy,
714 a_grid_buf,
715 a_block_buf,
716 a_block_slice_copy_step,
717 b_grid_desc_bk0_n_bk1,
718 b_block_desc_bk0_n_bk1,
719 b_blockwise_copy,
720 b_grid_buf,
721 b_block_buf,
722 b_block_slice_copy_step,
723 blockwise_gemm,
724 c_thread_buf,
725 num_k_block_main_loop);
726
727 // shuffle C and write out
728 {
729 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
730 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
731 "wrong!");
732
733 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
734 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
735
736 // TODO: hacky, fix it!
737 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
738 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
739
740 // TODO: hacky, fix it!
741 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
742 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
743 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
744
745 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
746 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
747 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
748 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
749 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
750 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
751 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
752 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
753
754 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
756
757 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
758 static_cast<CShuffleDataType*>(p_shared),
759 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
760
761 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
762 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
766 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
767 M1, // M1 = MWave
768 M2, // M2 * M3 * M4 = MPerXdl
769 M3,
770 M4)),
773 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
774 N1, // N1 = NWave
775 N2))), // N2 = NPerXdl
779
780 // calculate origin of thread output tensor on global memory
781 // blockwise GEMM c matrix starting index
782 const auto c_thread_mtx_on_block =
783 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
784
785 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
786 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
787
788 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
790 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
793
794 const auto m_thread_data_on_block_idx =
795 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
796 make_multi_index(m_thread_data_on_block));
797
798 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
803
804 const auto n_thread_data_on_block_idx =
805 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
806 make_multi_index(n_thread_data_on_block));
807
809 const auto& vpgr_to_lds_element_op = [&] {
810 if constexpr(DoElementwiseBeforeCShuffle)
811 {
812 return cde_element_op;
813 }
814 else
815 {
816 return pass_through;
817 }
818 };
819 const auto& lds_to_global_element_op = [&] {
820 if constexpr(!DoElementwiseBeforeCShuffle)
821 {
822 return cde_element_op;
823 }
824 else
825 {
826 return pass_through;
827 }
828 };
829
830 // shuffle: threadwise copy C from VGPR to LDS
831 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
832 AccDataType,
833 CShuffleDataType,
834 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
835 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
836 conditional_t<DoElementwiseBeforeCShuffle,
837 CDEElementwiseOperation,
839 Sequence<CShuffleMXdlPerWavePerShuffle,
840 CShuffleNXdlPerWavePerShuffle,
841 I1,
842 I1,
843 M2,
844 I1,
845 M4,
846 I1>,
848 7,
849 1,
851 1,
852 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
854 0,
855 m_thread_data_on_block_idx[I1],
856 n_thread_data_on_block_idx[I1],
857 m_thread_data_on_block_idx[I2],
858 m_thread_data_on_block_idx[I3],
859 m_thread_data_on_block_idx[I4],
860 n_thread_data_on_block_idx[I2]),
861 vpgr_to_lds_element_op()};
862
863 // tuple of reference to C/Ds tensor descriptors
864 const auto c_ds_desc_refs = concat_tuple_of_reference(
865 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
866 generate_tie([&](auto i) -> const auto& // return type should be reference
867 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
869
870 // tuple of reference to C/Ds tensor descriptors
871 const auto c_ds_buf_refs = concat_tuple_of_reference(
872 tie(c_shuffle_block_buf),
873 generate_tie([&](auto i) -> const auto& // return type should be reference
874 { return ds_grid_buf[i]; },
876
877 // tuple of starting index of C/Ds blockwise copy
878 const auto idx_c_ds_block_begin = container_concat(
879 make_tuple(make_multi_index(0, 0, 0, 0)),
881 [&](auto) {
882 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
883 },
885
886 // blockwise copy C/D/E between LDS and global
887 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
889 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
891 decltype(c_ds_desc_refs),
892 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
893 conditional_t<!DoElementwiseBeforeCShuffle,
894 CDEElementwiseOperation,
896 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
897 // support arbitray type
898 Sequence<1,
899 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
900 1,
901 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
902 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
903 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
904 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
905 3, // index_t VectorDim,
906 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
910 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
911 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
912 {c_ds_desc_refs,
913 idx_c_ds_block_begin,
914 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
915 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
916 lds_to_global_element_op()};
917
918 // space filling curve for threadwise C in VGPR before shuffle
919 constexpr auto sfc_c_vgpr =
922 Sequence<CShuffleMXdlPerWavePerShuffle,
923 CShuffleNXdlPerWavePerShuffle,
924 1,
925 1,
926 M2,
927 1,
928 M4,
929 1>>{};
930
931 // space filling curve for shuffled blockwise C/D/E
932 constexpr auto sfc_cde_block =
935 Sequence<1,
936 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
937 1,
938 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
939
940 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
941
942 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
943
944 static_for<0, num_access, 1>{}([&](auto access_id) {
945 // make sure it's safe to write to LDS
947
948 // each thread write its data from VGPR to LDS
949 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
950 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
951 c_thread_buf,
952 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
953 c_shuffle_block_buf);
954
955 // make sure it's safe to read from LDS
957
958 // each block copy its data from LDS to global
959 cde_block_copy_lds_and_global.Run(
960 c_ds_desc_refs,
961 c_ds_buf_refs,
962 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
963 tie(e_grid_buf));
964
965 if constexpr(access_id < num_access - 1)
966 {
967 constexpr auto cde_lds_and_global_step =
968 sfc_cde_block.GetForwardStep(access_id);
969
970 // move on Ds
971 static_for<0, NumDTensor, 1>{}([&](auto i) {
972 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
973 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
974 });
975
976 // move on E
977 cde_block_copy_lds_and_global.MoveDstSliceWindow(
978 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
979 I0,
980 cde_lds_and_global_step);
981 }
982 });
983 }
984 }
985
986 template <bool HasMainKBlockLoop,
987 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
988 GemmSpecialization GemmSpec,
989 typename ALayout,
990 typename BLayout,
991 typename DsLayout,
992 typename ELayout,
993 typename Block2ETileMap>
994 __device__ static void Run(const void* __restrict__ p_a_grid_,
995 const void* __restrict__ p_b_grid_,
996 DsGridPointer p_ds_grid,
997 void* __restrict__ p_e_grid_,
998 void* __restrict__ p_shared,
999 const AElementwiseOperation& a_element_op,
1000 const BElementwiseOperation& b_element_op,
1001 const CDEElementwiseOperation& cde_element_op,
1002 const index_t M,
1003 const index_t N,
1004 const index_t K,
1005 const index_t StrideA,
1006 const index_t StrideB,
1007#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
1008 const ck::Array<index_t, NumDTensor> StrideDs,
1009#else
1010 const std::array<index_t, NumDTensor> StrideDs,
1011#endif
1012 const index_t StrideE,
1013 const Block2ETileMap& block_2_etile_map)
1014 {
1015 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
1016 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
1017 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1018
1019 // tensor descriptors for problem definiton
1020 const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
1021 const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
1022
1023 using DsGridDesc_M_N =
1025
1026 DsGridDesc_M_N ds_grid_desc_m_n;
1027
1028 static_for<0, NumDTensor, 1>{}([&](auto j) {
1029 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1030
1031 ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1032 });
1033
1034 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1035
1036 // tensor descriptors for block/thread-wise copy
1037 const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
1038
1039 const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
1040
1041 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
1043 DsGridDesc_M_N{}))>;
1044
1045 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
1046
1047 static_for<0, NumDTensor, 1>{}([&](auto j) {
1048 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
1050 });
1051
1052 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1054
1056 p_a_grid,
1057 p_b_grid,
1058 p_ds_grid,
1059 p_e_grid,
1060 p_shared,
1061 a_element_op,
1062 b_element_op,
1063 cde_element_op,
1064 a_grid_desc_ak0_m_ak1,
1065 b_grid_desc_bk0_n_bk1,
1066 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1067 e_grid_desc_mblock_mperblock_nblock_nperblock,
1068 block_2_etile_map);
1069 }
1070
1071 template <bool HasMainKBlockLoop,
1072 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
1073 typename AGridDesc_MK,
1074 typename BGridDesc_NK,
1075 typename DsGridDesc_MN,
1076 typename EGridDesc_MN,
1077 typename Block2ETileMap>
1078 __device__ static void Run(const void* __restrict__ p_a_grid_,
1079 const void* __restrict__ p_b_grid_,
1080 DsGridPointer p_ds_grid,
1081 void* __restrict__ p_e_grid_,
1082 void* __restrict__ p_shared,
1083 const AElementwiseOperation& a_element_op,
1084 const BElementwiseOperation& b_element_op,
1085 const CDEElementwiseOperation& cde_element_op,
1086 const AGridDesc_MK& a_grid_desc_m_k,
1087 const BGridDesc_NK& b_grid_desc_n_k,
1088 const DsGridDesc_MN& ds_grid_desc_m_n,
1089 const EGridDesc_MN& e_grid_desc_m_n,
1090 const Block2ETileMap& block_2_etile_map)
1091 {
1092 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
1093 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
1094 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1095
1096 // tensor descriptors for block/thread-wise copy
1097 const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
1098 const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
1099
1100 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
1102 DsGridDesc_MN{}))>;
1103
1104 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
1105
1106 static_for<0, NumDTensor, 1>{}([&](auto j) {
1107 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
1109 });
1110
1111 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1113
1115 p_a_grid,
1116 p_b_grid,
1117 p_ds_grid,
1118 p_e_grid,
1119 p_shared,
1120 a_element_op,
1121 b_element_op,
1122 cde_element_op,
1123 a_grid_desc_ak0_m_ak1,
1124 b_grid_desc_bk0_n_bk1,
1125 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1126 e_grid_desc_mblock_mperblock_nblock_nperblock,
1127 block_2_etile_map);
1128 }
1129};
1130
1131} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition utility/array.hpp:14
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static __device__ auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:493
static __device__ void Run(const void *__restrict__ p_a_grid_, const void *__restrict__ p_b_grid_, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_MK &a_grid_desc_m_k, const BGridDesc_NK &b_grid_desc_n_k, const DsGridDesc_MN &ds_grid_desc_m_n, const EGridDesc_MN &e_grid_desc_m_n, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:1078
static __device__ void Run(const void *__restrict__ p_a_grid_, const void *__restrict__ p_b_grid_, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideE, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:994
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map, const index_t k_batch=1, const index_t k_idx=0)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:517
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:155
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:147
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340