blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8#define DS_READ_A_PREFETCH_STAGES 2
9
// Splits `total_loads` over `stages` and returns both roundings of the quotient as a pair:
//   .first  - total_loads / stages rounded up (for non-negative total_loads and positive stages)
//   .second - total_loads / stages rounded down (plain integer division)
// `stages` must be non-zero.
template <typename T>
constexpr auto compute_stage_loads(T total_loads, T stages)
{
    const auto loads_rounded_up   = (total_loads + stages - 1) / stages;
    const auto loads_rounded_down = total_loads / stages;
    return std::make_pair(loads_rounded_up, loads_rounded_down);
}
17
18namespace ck {
19
20// Compute optimized pipeline
21// GlobalPrefetchStages: 2
22// LocalPreFillStages: 1
23// LocalPreFetchStages: 1
24// LocalSharedMemoryBuffer: 1
25
26template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
27 index_t BlockSize,
28 typename ADataType,
29 typename BDataType,
30 typename ComputeDataType,
31 typename AccDataType,
32 typename ATileDesc,
33 typename BTileDesc,
34 typename AMmaTileDesc,
35 typename BMmaTileDesc,
36 index_t ABlockTransferSrcScalarPerVector,
37 index_t BBlockTransferSrcScalarPerVector,
38 index_t MPerBlock,
39 index_t NPerBlock,
40 index_t KPerBlock,
41 index_t MPerXDL,
42 index_t NPerXDL,
43 index_t MRepeat,
44 index_t NRepeat,
45 index_t KPacks>
49
50template <index_t BlockSize,
51 typename ADataType,
52 typename BDataType,
53 typename ComputeDataType,
54 typename AccDataType,
55 typename ATileDesc,
56 typename BTileDesc,
57 typename AMmaTileDesc,
58 typename BMmaTileDesc,
59 index_t ABlockTransferSrcScalarPerVector,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t MPerBlock,
62 index_t NPerBlock,
63 index_t KPerBlock,
64 index_t MPerXDL,
65 index_t NPerXDL,
66 index_t MRepeat,
67 index_t NRepeat,
68 index_t KPack
69 // ,bool TransposeC //disable transposec right now...
70 >
72 BlockSize,
73 ADataType,
74 BDataType,
75 ComputeDataType,
76 AccDataType,
77 ATileDesc,
78 BTileDesc,
79 AMmaTileDesc,
80 BMmaTileDesc,
81 ABlockTransferSrcScalarPerVector,
82 BBlockTransferSrcScalarPerVector,
83 MPerBlock,
84 NPerBlock,
85 KPerBlock,
86 MPerXDL,
87 NPerXDL,
88 MRepeat,
89 NRepeat,
90 KPack>
92 ADataType,
93 BDataType,
94 ComputeDataType,
95 AccDataType,
96 ATileDesc,
97 BTileDesc,
98 AMmaTileDesc,
99 BMmaTileDesc,
100 ABlockTransferSrcScalarPerVector,
101 BBlockTransferSrcScalarPerVector,
102 MPerBlock,
103 NPerBlock,
104 KPerBlock,
105 MPerXDL,
106 NPerXDL,
107 MRepeat,
108 NRepeat,
109 KPack>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack>;
131 using Base::A_K1;
132 using Base::B_K1;
133 using Base::I0;
134 using Base::I1;
135 using Base::I2;
136 using Base::KGroup;
137 using Base::KRepeat;
138 using Base::xdlops_gemm;
139 using typename Base::HotLoopInstList;
140
153 using Base::MWaves;
154 using Base::WaveSize;
155
156 static constexpr index_t PrefetchStages = 2;
157 static constexpr index_t PrefillStages = 1;
158 static constexpr index_t GlobalBufferNum = 1;
159 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
160
// Builds the A-side descriptor used by the MMA stage from a tile descriptor whose first three
// lengths are read as M0, M1 and M2 (dimensions 0, 1 and 2 of TileDesc_M0_M1_M2_K).
// The K extents computed here are:
//   K2 = KPack / KGroup
//   K1 = WaveSize / NPerXDL
//   K0 = KRepeat * KGroup
// NOTE(review): listing lines 171 and 173-179 (the return statement that consumes M0..K2) are not
// visible here; presumably K is split into (K0, K1, K2), matching the member name
// a_block_desc_m0_m1_m2_k0_k1_k2 - confirm against the full file.
// NOTE(review): K1 is derived from NPerXDL although this is the A (M-side) descriptor - confirm
// this is intended (the two agree whenever MPerXDL == NPerXDL).
161 template <typename TileDesc_M0_M1_M2_K>
162 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
163 {
164 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
165 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
166 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
167 constexpr index_t K2 = KPack / KGroup;
168 constexpr index_t K1 = WaveSize / NPerXDL;
169 constexpr index_t K0 = KRepeat * KGroup;
170
172 TileDesc_M0_M1_M2_K{},
180 }
181
182 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
184
185 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
186 {
187 return num_loop > PrefetchStages;
188 }
189
190 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
191 {
192 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
193 }
194
// Emits the compile-time scheduling hints (__builtin_amdgcn_sched_group_barrier) for the hot loop.
// The hints are generated in three groups, in this order:
//   1. "B global read":                 per MFMA slot, optionally a VMEM group and an LDS-read group;
//   2. "A global read + A local write": per MFMA slot, optionally LDS-write, VMEM and LDS-read groups;
//   3. "lds synchronization, prefetch next loop local A": per MFMA slot, optionally an LDS-read group.
// Every quantity is constexpr and derived from HotLoopInstList, MRepeat and DS_READ_A_PREFETCH_STAGES.
// NOTE(review): the listing's line numbers skip 199-200, 261, 282 and 310, so the tail of the
// num_ds_read_inst_a initializer and the headers of the three outer loops over `i` are not visible
// here; the trip counts of those loops must be confirmed against the full file.
195 __device__ static constexpr auto HotLoopScheduler()
196 {
197 constexpr auto num_ds_read_inst_a =
198 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
201
202 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
203
204 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
205 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
206
// The A section below pairs one LDS-write group with one VMEM group, which relies on the
// A buffer-load count matching the A LDS-write count.
207 static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
208
209 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
210 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
211
// Issue cost used for one A LDS read: 8 when A_LDS_Read_Width * sizeof(ADataType) is 16, else 4.
212 constexpr auto ds_read_a_issue_cycle =
213 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// Size of the LDS-read group attached to one MFMA slot:
// ceil((mfma_cycle - 4) / (2 * ds_read_a_issue_cycle)).
214 constexpr auto ds_read_a_mfma_rate =
215 math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
216
// One stage per MRepeat index.
217 constexpr auto num_total_stages = MRepeat;
218
219 // Group num_mfma_perstage num_ds_read_a_perstage
220 // since we want to reuse a local register buffer
221 constexpr auto num_mfma_perstage = num_mfma_inst / MRepeat;
222 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / MRepeat;
223
// Number of MFMA slots per stage that carry an LDS-read group; they are the last slots of the
// stage (see the `imfma >= num_mfma_perstage - num_ds_read_a_mfma_perstage` tests below).
224 constexpr auto num_ds_read_a_mfma_perstage =
225 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
226
// Buffer loads of A and B together are spread over the stages that remain once
// DS_READ_A_PREFETCH_STAGES stages are set aside.
// NOTE(review): stages_available is 0 when MRepeat == DS_READ_A_PREFETCH_STAGES, which makes
// compute_stage_loads divide by zero at compile time - presumably such configs are never
// instantiated; confirm.
227 constexpr auto total_buffer_loads = num_buffer_load_inst_a + num_buffer_load_inst_b;
228 constexpr auto stages_available = MRepeat - DS_READ_A_PREFETCH_STAGES;
229
230 constexpr auto stage_loads = compute_stage_loads(total_buffer_loads, stages_available);
231
// .first is the rounded-up loads per stage, .second the rounded-down loads per stage.
232 constexpr auto buffer_load_perstage_more = stage_loads.first;
233 constexpr auto buffer_load_perstage_less = stage_loads.second;
234
// Number of stages that take the rounded-up count (the remainder of the division).
235 constexpr auto buffer_load_stages_more = total_buffer_loads % stages_available;
236
// Loads covered by the "more" stages, and the B loads left over once those stages are filled.
237 constexpr auto buffer_b_heavy_loads = buffer_load_perstage_more * buffer_load_stages_more;
238 constexpr auto buffer_b_remaining =
239 num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more;
240
// Stages given to B loads: if the "more" stages alone hold more than all B loads, B takes
// num_b / more stages; otherwise it takes every "more" stage plus remaining / less stages.
241 constexpr auto buffer_load_b_stages =
242 buffer_b_heavy_loads > num_buffer_load_inst_b
243 ? num_buffer_load_inst_b / buffer_load_perstage_more
244 : (buffer_load_stages_more + buffer_b_remaining / buffer_load_perstage_less);
245
// Whatever is left after the prefetch stages and the B stages goes to A loads.
246 constexpr auto buffer_load_a_stages =
247 num_total_stages - DS_READ_A_PREFETCH_STAGES - buffer_load_b_stages;
248
249 static_assert(buffer_load_a_stages > 0,
250 "The buffer load a stages should always have a value over 0.");
251
// Spacing, in MFMA slots, between consecutive load issue points inside a stage.
// NOTE(review): when buffer_load_perstage_less is 0 the interval becomes INT32_MAX, yet
// `imfma % INT32_MAX == 0` still holds for imfma == 0, so a "less" stage still emits one
// group at slot 0 - confirm this is intended.
252 constexpr auto buffer_load_issue_point_interval_more =
253 math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_more);
254 constexpr auto buffer_load_issue_point_interval_less =
255 buffer_load_perstage_less == 0
256 ? INT32_MAX
257 : math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_less);
// Slot offset inside an interval at which the A VMEM group is placed: 1 when a stage has at
// least 3 MFMAs, otherwise 0 (the same slot as the LDS-write group).
258 constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
259
260 // B global read
// (the outer loop header over `i` - listing line 261 - is not visible here)
262 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
263 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
264
// Stages with index below buffer_load_stages_more use the "more" interval, the rest the "less" one.
265 if constexpr(((i < buffer_load_stages_more) &&
266 (imfma % buffer_load_issue_point_interval_more == 0)) ||
267 ((i >= buffer_load_stages_more) &&
268 (imfma % buffer_load_issue_point_interval_less == 0)))
269 {
270 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
271 }
272
// The last num_ds_read_a_mfma_perstage slots of the stage each carry an LDS-read group.
273 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
274 {
275 __builtin_amdgcn_sched_group_barrier(
276 SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
277 }
278 });
279 });
280
281 // A global read + A local write
// (the outer loop header over `i` - listing line 282 - is not visible here)
// Stage indices are offset by buffer_load_b_stages so the more/less split continues from
// where the B stages stopped.
283 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
284 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
// LDS-write group at the start of each interval.
285 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
286 (imfma % buffer_load_issue_point_interval_more == 0)) ||
287 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
288 (imfma % buffer_load_issue_point_interval_less == 0)))
289 {
290 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0);
291 }
// VMEM group at offset buffer_load_issue_point_a inside each interval.
292 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
293 (imfma % buffer_load_issue_point_interval_more ==
294 buffer_load_issue_point_a)) ||
295 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
296 (imfma % buffer_load_issue_point_interval_less ==
297 buffer_load_issue_point_a)))
298 {
299 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
300 }
301 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
302 {
303 __builtin_amdgcn_sched_group_barrier(
304 SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
305 }
306 });
307 });
308
309 // lds synchronization, prefetch next loop local A
// (the outer loop header over `i` - listing line 310 - is not visible here)
// `i` is not needed in this section; assigning it to `ignore` marks it as used.
311 ignore = i;
312 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
313 __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
314 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
315 {
316 __builtin_amdgcn_sched_group_barrier(
317 SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
318 }
319 });
320 });
321 }
322
// Runs the block-level GEMM pipeline over `num_loop` K-iterations, accumulating into c_thread_buf.
//
// Structure visible in this listing:
//   * Prologue: B is copied from b_grid_buf into b_thread_bufs(I0); A is read from a_grid_buf,
//     written to a_block_buf.At(I0), and a second A global read is issued; A is then copied from
//     a_block_buf.At(I0) into a_thread_buf; c_thread_buf is cleared.
//   * Main loop (only when HasMainLoop): each pass calls LoopFunc(I0, I1) then LoopFunc(I1, I0)
//     and advances `i` by 2, repeating while i < num_loop - 2.
//   * Tail: TailNumber::Even runs two MFMA passes (b_thread_bufs[I0], then b_thread_bufs[I1]);
//     TailNumber::Odd runs one MFMA pass (b_thread_bufs[I0]).
//
// b_block_buf is not used (it is assigned to `ignore`): B goes from b_grid_buf to b_thread_bufs.
// The step arguments are passed to MoveSrcSliceWindow after each global read of A or B.
//
// NOTE(review): this listing skips many line numbers inside the function (e.g. 353, 355, 364,
// 381-382, 385-386, 456, 535, 583, 641, 692); the declarations of a_thread_buf / b_thread_buf /
// a_thread_vec / b_thread_vec, several call arguments and any LDS synchronization calls are
// therefore not visible here and must be checked in the full file.
323 template <bool HasMainLoop,
324 TailNumber TailNum,
325 typename AGridDesc,
326 typename ABlockDesc,
327 typename ABlockTransfer,
328 typename AGridBuffer,
329 typename ABlockBuffer,
330 typename ABlockTransferStep,
331 typename BGridDesc,
332 typename BBlockTransfer,
333 typename BGridBuffer,
334 typename BBlockBuffer,
335 typename BBlockTransferStep,
336 typename CThreadBuffer>
337 __device__ void Run(const AGridDesc& a_grid_desc,
338 const ABlockDesc& a_block_desc,
339 ABlockTransfer& a_blockwise_copy,
340 const AGridBuffer& a_grid_buf,
341 ABlockBuffer& a_block_buf,
342 const ABlockTransferStep& a_block_copy_step,
343 const BGridDesc& b_grid_desc,
344 BBlockTransfer& b_blockwise_copy,
345 const BGridBuffer& b_grid_buf,
346 BBlockBuffer& b_block_buf,
347 const BBlockTransferStep& b_block_copy_step,
348 CThreadBuffer& c_thread_buf,
349 index_t num_loop) const
350 {
// B never goes through the block buffer in this pipeline; silence the unused parameter.
351 ignore = b_block_buf;
352 __builtin_amdgcn_sched_barrier(0);
354 a_thread_desc_.GetElementSpaceSize());
356 b_thread_desc_.GetElementSpaceSize());
357
// Two B thread buffers, indexed by I0 / I1 and alternated between loop halves.
358 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
359 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
360
361 // Global prefetch A1 B1
362 b_blockwise_copy.Run(b_grid_desc,
363 b_grid_buf,
365 b_block_origin_idx,
366 b_thread_bufs(I0));
367 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
368
369 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
370 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
371 __builtin_amdgcn_sched_barrier(0);
372
373 // Local prefill A1
374 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
375
376 // Global prefetch A2
377 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
378 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
379
380 // Local prefetch A1
383 static_for<0, KRepeat, 1>{}([&](auto k0) {
384 static_for<0, KGroup, 1>{}([&](auto kg0) {
387 a_block_buf.At(I0),
389 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
390 a_thread_buf);
391 });
392 });
393 });
394
395 // Initialize C
396 c_thread_buf.Clear();
397
398 __builtin_amdgcn_sched_barrier(0);
399
400 // main body
401 if constexpr(HasMainLoop)
402 {
403 index_t i = 0;
404 do
405 {
// One half-iteration of the hot loop.
//   mfma_reg_buf   - index of the b_thread_bufs entry fed to the MFMAs, also used to pick the
//                    a_thread_buf slot and the a_block_buf entry read in the generic branch.
//   local_read_buf - index of the b_thread_bufs entry filled by the new B load and of the
//                    a_block_buf entry written by RunWrite.
406 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
407 b_blockwise_copy.Run(b_grid_desc,
408 b_grid_buf,
410 b_block_origin_idx,
411 b_thread_bufs(local_read_buf));
412 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
413
// Store the previously read A tile into the block buffer, then start the next A global read.
414 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
415 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
416 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
417
418 static_for<0, MRepeat, 1>{}([&](auto m0) {
419 static_for<0, KRepeat, 1>{}([&](auto k0) {
420 static_for<0, NRepeat, 1>{}([&](auto n0) {
423
// Gather KPack A and B elements for this (m0, n0, k0). The A index is taken modulo 2,
// shifted by HotloopLocalBufSwitch * mfma_reg_buf (non-zero only when MRepeat is odd).
424 static_for<0, KPack, 1>{}([&](auto ik) {
425 a_thread_vec.template AsType<ComputeDataType>()(ik) =
426 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
427 make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
428 2,
429 I0,
430 I0,
431 k0,
432 I0,
433 ik))>{}];
434 b_thread_vec.template AsType<ComputeDataType>()(ik) =
435 b_thread_bufs[mfma_reg_buf]
436 [Number<b_thread_desc_.CalculateOffset(
437 make_tuple(n0, I0, k0, ik))>{}];
438 });
439
440 using mfma_input_type =
441 typename vector_type<ComputeDataType,
442 xdlops_gemm.K1PerXdlops>::type;
443
444 constexpr index_t c_offset =
445 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
446
447 xdlops_gemm.Run(
448 a_thread_vec.template AsType<mfma_input_type>(),
449 b_thread_vec.template AsType<mfma_input_type>(),
450 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
451 });
452 });
453
// Refill the a_thread_buf slot for m0 + 2. The last two m0 values wrap around and read
// from the block buffer just written (local_read_buf); earlier m0 values keep reading from
// the buffer currently being consumed (mfma_reg_buf).
454 if constexpr(m0.value == (MRepeat - 2))
455 {
457
458 static_for<0, KRepeat, 1>{}([&](auto k0) {
459 static_for<0, KGroup, 1>{}([&](auto kg0) {
460 a_thread_copy_.Run(
463 I0,
464 I0,
466 I0,
467 I0),
468 a_block_buf.At(local_read_buf),
471 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
472 2>{},
473 I0,
474 I0,
475 k0,
476 I0,
478 a_thread_buf);
479 });
480 });
481 }
482 else if constexpr(m0.value == (MRepeat - 1))
483 {
484 static_for<0, KRepeat, 1>{}([&](auto k0) {
485 static_for<0, KGroup, 1>{}([&](auto kg0) {
486 a_thread_copy_.Run(
488 make_tuple(Number<(m0 + 2) % MRepeat>{},
489 I0,
490 I0,
492 I0,
493 I0),
494 a_block_buf.At(local_read_buf),
497 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
498 2>{},
499 I0,
500 I0,
501 k0,
502 I0,
504 a_thread_buf);
505 });
506 });
507 }
508 else
509 {
510 static_for<0, KRepeat, 1>{}([&](auto k0) {
511 static_for<0, KGroup, 1>{}([&](auto kg0) {
512 a_thread_copy_.Run(
514 make_tuple(Number<(m0 + 2) % MRepeat>{},
515 I0,
516 I0,
518 I0,
519 I0),
520 a_block_buf.At(mfma_reg_buf),
523 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
524 2>{},
525 I0,
526 I0,
527 k0,
528 I0,
530 a_thread_buf);
531 });
532 });
533 }
534 });
536 };
537
// Two half-iterations per pass, with the two buffer indices swapped between them.
538 LoopFunc(I0, I1);
539 LoopFunc(I1, I0);
540
541 i += 2;
542 } while(i < (num_loop - 2));
543 }
544 // tail
545 if constexpr(TailNum == TailNumber::Even)
546 {
// Even tail: two iterations remain. Fetch the last B tile into b_thread_bufs(I1) and store the
// last A tile into a_block_buf.At(I1); no further global read of A is issued.
547 b_blockwise_copy.Run(b_grid_desc,
548 b_grid_buf,
550 b_block_origin_idx,
551 b_thread_bufs(I1));
552 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
553
// First tail pass: MFMAs consume b_thread_bufs[I0].
554 static_for<0, MRepeat, 1>{}([&](auto m0) {
555 static_for<0, KRepeat, 1>{}([&](auto k0) {
556 static_for<0, NRepeat, 1>{}([&](auto n0) {
559
560 static_for<0, KPack, 1>{}([&](auto ik) {
561 a_thread_vec.template AsType<ComputeDataType>()(ik) =
562 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
563 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
564 b_thread_vec.template AsType<ComputeDataType>()(ik) =
565 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
566 make_tuple(n0, I0, k0, ik))>{}];
567 });
568
569 using mfma_input_type =
570 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
571
572 constexpr index_t c_offset =
573 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
574
575 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
576 b_thread_vec.template AsType<mfma_input_type>(),
577 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
578 });
579 });
580
// Same refill pattern as the hot loop: the last two m0 values read from a_block_buf.At(I1),
// earlier ones from a_block_buf.At(I0).
581 if constexpr(m0.value == (MRepeat - 2))
582 {
584
585 static_for<0, KRepeat, 1>{}([&](auto k0) {
586 static_for<0, KGroup, 1>{}([&](auto kg0) {
587 a_thread_copy_.Run(
591 a_block_buf.At(I1),
594 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
595 a_thread_buf);
596 });
597 });
598 }
599 else if constexpr(m0.value == (MRepeat - 1))
600 {
601 static_for<0, KRepeat, 1>{}([&](auto k0) {
602 static_for<0, KGroup, 1>{}([&](auto kg0) {
603 a_thread_copy_.Run(
605 make_tuple(Number<(m0 + 2) % MRepeat>{},
606 I0,
607 I0,
609 I0,
610 I0),
611 a_block_buf.At(I1),
614 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
615 a_thread_buf);
616 });
617 });
618 }
619 else
620 {
621 static_for<0, KRepeat, 1>{}([&](auto k0) {
622 static_for<0, KGroup, 1>{}([&](auto kg0) {
623 a_thread_copy_.Run(
625 make_tuple(Number<(m0 + 2) % MRepeat>{},
626 I0,
627 I0,
629 I0,
630 I0),
631 a_block_buf.At(I0),
634 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
635 a_thread_buf);
636 });
637 });
638 }
639 });
640
642
// Second tail pass: MFMAs consume b_thread_bufs[I1]; A slot indices are shifted by
// HotloopLocalBufSwitch, and A is refilled from a_block_buf.At(I1) only while m0 < MRepeat - 2.
643 static_for<0, MRepeat, 1>{}([&](auto m0) {
644 static_for<0, KRepeat, 1>{}([&](auto k0) {
645 static_for<0, NRepeat, 1>{}([&](auto n0) {
648
649 static_for<0, KPack, 1>{}([&](auto ik) {
650 a_thread_vec.template AsType<ComputeDataType>()(ik) =
651 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
652 (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
653 b_thread_vec.template AsType<ComputeDataType>()(ik) =
654 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
655 make_tuple(n0, I0, k0, ik))>{}];
656 });
657
658 using mfma_input_type =
659 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
660
661 constexpr index_t c_offset =
662 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
663
664 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
665 b_thread_vec.template AsType<mfma_input_type>(),
666 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
667 });
668 });
669
670 if constexpr(m0.value < (MRepeat - 2))
671 {
672 static_for<0, KRepeat, 1>{}([&](auto k0) {
673 static_for<0, KGroup, 1>{}([&](auto kg0) {
674 a_thread_copy_.Run(
678 a_block_buf.At(I1),
680 make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
681 I0,
682 I0,
683 k0,
684 I0,
686 a_thread_buf);
687 });
688 });
689 }
690 });
691
693 }
694 else if constexpr(TailNum == TailNumber::Odd)
695 {
// Odd tail: a single iteration remains. MFMAs consume b_thread_bufs[I0], and A is refilled
// from a_block_buf.At(I0) only while m0 < MRepeat - 2 (nothing is prefetched past the end).
696 static_for<0, MRepeat, 1>{}([&](auto m0) {
697 static_for<0, KRepeat, 1>{}([&](auto k0) {
698 static_for<0, NRepeat, 1>{}([&](auto n0) {
701
702 static_for<0, KPack, 1>{}([&](auto ik) {
703 a_thread_vec.template AsType<ComputeDataType>()(ik) =
704 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
705 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
706 b_thread_vec.template AsType<ComputeDataType>()(ik) =
707 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
708 make_tuple(n0, I0, k0, ik))>{}];
709 });
710
711 using mfma_input_type =
712 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
713
714 constexpr index_t c_offset =
715 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
716
717 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
718 b_thread_vec.template AsType<mfma_input_type>(),
719 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
720 });
721 });
722
723 if constexpr(m0.value < (MRepeat - 2))
724 {
725 static_for<0, KRepeat, 1>{}([&](auto k0) {
726 static_for<0, KGroup, 1>{}([&](auto kg0) {
727 a_thread_copy_.Run(
731 a_block_buf.At(I0),
734 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
735 a_thread_buf);
736 });
737 });
738 }
739 });
740 }
741 }
742
743 protected:
744 // MRepeat MWave MLane KRepeat KLane KPack
745 // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
746 // Reduce the vgpr usage here.
749
751 ComputeDataType,
753 decltype(a_thread_desc_),
754 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
756 5,
757 A_K1,
758 A_K1>;
759
761
764
765 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
766
768};
769
770} // namespace ck
#define DS_READ_A_PREFETCH_STAGES
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:8
constexpr auto compute_stage_loads(T total_loads, T stages)
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:11
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ SCHED_GROUP_LDS_READ
Definition blkgemmpipe_scheduler.hpp:56
@ SCHED_GROUP_MFMA
Definition blkgemmpipe_scheduler.hpp:54
@ SCHED_GROUP_LDS_WRITE
Definition blkgemmpipe_scheduler.hpp:57
@ SCHED_GROUP_VMEM
Definition blkgemmpipe_scheduler.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
#define INT32_MAX
Definition stdint.h:182
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops_base.hpp:44
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr index_t KGroup
Definition blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr auto I2
Definition blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:112
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:337
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:750
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp:47
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10