blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::I2;
126 using Base::KRepeat;
127 using typename Base::HotLoopInstList;
128
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144
145 using Base::MWaves;
146 using Base::WaveSize;
147
150
151 static constexpr index_t PrefetchStages = 2;
152 static constexpr index_t PrefillStages = 1;
153 static constexpr index_t GlobalBufferNum = 1;
154 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
155
// Builds the A-side MMA tile descriptor from a tile descriptor type whose
// dimensions 0, 1 and 2 are read below as M0, M1 and M2. The argument is used
// only for its type; every length is obtained from a default-constructed
// TileDesc_M0_M1_M2_K at compile time.
//
// NOTE(review): listing lines 166 and 168-174 (the head and remaining arguments
// of the returned expression) are absent from this extract, so how M0/M1/M2 and
// K0/K1/K2 are combined cannot be confirmed from here. Presumably the K
// dimension is split into (K0, K1, K2) -- verify against the original header.
156 template <typename TileDesc_M0_M1_M2_K>
157 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
158 {
159 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
160 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
161 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
// K2 is the innermost K length (KPack); K0 is the KRepeat count from Base.
162 constexpr index_t K2 = KPack;
// NOTE(review): K1 is derived from NPerXDL even though this is the A-side
// descriptor -- confirm that NPerXDL (not MPerXDL) is intended here.
163 constexpr index_t K1 = WaveSize / NPerXDL;
164 constexpr index_t K0 = KRepeat;
165
167 TileDesc_M0_M1_M2_K{},
175 }
176
177 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
179
// Reports whether the hot loop is needed for the given iteration count.
// Returns true only when num_loop is strictly greater than PrefetchStages
// (declared as 2 in this struct).
180 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
181 {
182 return num_loop > PrefetchStages;
183 }
184
// Selects the tail variant from the parity of the iteration count:
// TailNumber::Even when num_loop is even, TailNumber::Odd otherwise.
185 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
186 {
187 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
188 }
189
// Emits the scheduling hints (sched_group_barrier sequences) for one stage of
// the hot loop. The stage is chosen at compile time from stage.value:
//   0         -> interleave MFMA with B global loads (and A LDS reads)
//   1         -> interleave MFMA with A LDS writes (and A LDS reads)
//   2         -> interleave MFMA with A global loads (and A LDS reads)
//   otherwise -> interleave MFMA with A LDS reads only
// Every branch ends with __builtin_amdgcn_sched_barrier(0).
// Group masks as annotated inline: 0x008 MFMA, 0x020 VMEM read, 0x100 DS read,
// 0x200 DS write.
//
// NOTE(review): listing lines 214, 290 and 332 (the openers of the loops whose
// bodies use `i_inst`) are absent from this extract; their trip counts cannot
// be confirmed from here.
190 template <typename Stage>
191 __device__ static constexpr auto HotLoopScheduler(Stage stage)
192 {
193 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
194 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
195 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
196 constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
197
198 constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
199
// Per-stage budgets: the totals are divided by MRepeat, rounding up.
200 constexpr auto staged_num_ds_read_inst_a =
201 ck::math::integer_divide_ceil(num_ds_read_inst_a, MRepeat);
202 constexpr auto staged_num_mfma = ck::math::integer_divide_ceil(num_mfma, MRepeat);
203
204 constexpr auto staged_num_mfma_per_ds_read_a =
205 ck::math::integer_divide_ceil(staged_num_mfma, staged_num_ds_read_inst_a);
206
207 if constexpr(stage.value == 0)
208 {
209 constexpr auto staged_num_buffer_load_b_per_ds_read_a =
210 ck::math::integer_divide_ceil(num_buffer_load_inst_b, staged_num_ds_read_inst_a);
211 constexpr auto staged_num_mfma_per_buffer_load_b =
212 ck::math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_b);
213 // B global
215 ignore = i_inst;
216
// All but the last B load of this group: a full MFMA batch, then one load.
217 static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
218 ignore = ibuf_inst;
219 __builtin_amdgcn_sched_group_barrier(
220 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
221 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
222 });
223
// Last B load of the group: one MFMA is split off so that a DS read can be
// placed between it and the remaining MFMAs.
224 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
225 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
226 __builtin_amdgcn_sched_group_barrier(
227 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
228 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
229 });
230
231 __builtin_amdgcn_sched_barrier(0);
232 }
233 else if constexpr(stage.value == 1)
234 {
235 constexpr auto staged_num_mfma_per_ds_write_a =
236 math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
237
// Number of leading iterations that schedule the rounded-up MFMA count; the
// remaining iterations schedule one MFMA fewer.
238 constexpr auto stage_more_mfma =
239 staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
240
241 // A local write
242 static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
243 if constexpr(i_inst.value < stage_more_mfma)
244 {
// While DS reads remain for this stage, one MFMA of the batch is moved after
// the DS write so a DS read can follow it. (Plain `if` on compile-time values.)
245 if(i_inst.value < staged_num_ds_read_inst_a)
246 {
247 __builtin_amdgcn_sched_group_barrier(
248 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
249 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
250 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
251 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
252 }
253 else
254 {
255 __builtin_amdgcn_sched_group_barrier(
256 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
257 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
258 }
259 }
260 else
261 {
262 if(i_inst.value < staged_num_ds_read_inst_a)
263 {
264 __builtin_amdgcn_sched_group_barrier(
265 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
266 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
267 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
268 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
269 }
270 else
271 {
272 __builtin_amdgcn_sched_group_barrier(
273 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
274 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
275 }
276 }
277 });
278
279 __builtin_amdgcn_sched_barrier(0);
280 }
281 else if constexpr(stage.value == 2)
282 {
283 constexpr auto staged_num_mfma_per_buffer_load_a =
284 math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
285
// Same remainder distribution as the A-local-write stage, with A global loads
// (VMEM read) in place of DS writes.
286 constexpr auto stage_more_mfma =
287 staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
288
289 // A global
291 if constexpr(i_inst.value < stage_more_mfma)
292 {
293 if(i_inst.value < staged_num_ds_read_inst_a)
294 {
295 __builtin_amdgcn_sched_group_barrier(
296 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
297 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
298 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
299 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
300 }
301 else
302 {
303 __builtin_amdgcn_sched_group_barrier(
304 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
305 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
306 }
307 }
308 else
309 {
310 if(i_inst.value < staged_num_ds_read_inst_a)
311 {
312 __builtin_amdgcn_sched_group_barrier(
313 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
314 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
315 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
316 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
317 }
318 else
319 {
320 __builtin_amdgcn_sched_group_barrier(
321 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
322 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
323 }
324 }
325 });
326
327 __builtin_amdgcn_sched_barrier(0);
328 }
329 else
330 {
331 // A local Read
333 ignore = i_inst;
334 __builtin_amdgcn_sched_group_barrier(
335 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
336 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
337 });
338
339 __builtin_amdgcn_sched_barrier(0);
340 }
341 }
342
// Emits the scheduling hints for one stage of the first epilogue step. The
// stage is chosen at compile time from stage.value:
//   0         -> interleave MFMA with B global loads (and A LDS reads)
//   1         -> interleave MFMA with A LDS writes (and A LDS reads)
//   otherwise -> interleave MFMA with A LDS reads only
// Every branch ends with __builtin_amdgcn_sched_barrier(0).
// Unlike HotLoopScheduler, the per-stage budgets here use truncating integer
// division rather than integer_divide_ceil.
//
// NOTE(review): listing lines 364, 367, 390, 393 and 454 (loop openers) are
// absent from this extract; their trip counts cannot be confirmed from here.
343 template <typename Stage>
344 __device__ static constexpr auto EpilogueScheduler_1(Stage stage)
345 {
346 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
347 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
348 constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
349
350 constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
351
352 constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
353 constexpr auto staged_num_mfma = num_mfma / MRepeat;
354
355 constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
356
357 if constexpr(stage.value == 0)
358 {
359 constexpr auto staged_num_buffer_load_b_per_ds_read_a =
360 num_buffer_load_inst_b / staged_num_ds_read_inst_a;
361 constexpr auto staged_num_mfma_per_buffer_load_b =
362 staged_num_mfma / num_buffer_load_inst_b;
363 // B global
365 ignore = i_inst;
366
368 ignore = ibuf_inst;
369 __builtin_amdgcn_sched_group_barrier(
370 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
371 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
372 });
373
// One MFMA is split off so that a DS read can be placed before the rest of the
// batch and the following B load.
374 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
375 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
376 __builtin_amdgcn_sched_group_barrier(
377 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
378 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
379 });
380
381 __builtin_amdgcn_sched_barrier(0);
382 }
383 else if constexpr(stage.value == 1)
384 {
// The `#if 0` branch is compiled out; the `#elif 1` branch below is the active
// schedule.
385#if 0
386 constexpr auto staged_num_ds_write_a_per_ds_read_a =
387 num_ds_write_inst_a / staged_num_ds_read_inst_a;
388 constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
389 // A local write
391 ignore = i_inst;
392
394 ignore = idswrite_inst;
395 __builtin_amdgcn_sched_group_barrier(
396 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
397 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
398 });
399
400 __builtin_amdgcn_sched_group_barrier(
401 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
402 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
403 });
404#elif 1
405 constexpr auto staged_num_mfma_per_ds_write_a =
406 math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
407
// Number of leading iterations that schedule the rounded-up MFMA count; the
// remaining iterations schedule one MFMA fewer.
408 constexpr auto stage_more_mfma =
409 staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
410
411 // A local write
412 static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
413 if constexpr(i_inst.value < stage_more_mfma)
414 {
// While DS reads remain for this stage, one MFMA of the batch is moved after
// the DS write so a DS read can follow it.
415 if(i_inst.value < staged_num_ds_read_inst_a)
416 {
417 __builtin_amdgcn_sched_group_barrier(
418 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
419 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
420 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
421 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
422 }
423 else
424 {
425 __builtin_amdgcn_sched_group_barrier(
426 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
427 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
428 }
429 }
430 else
431 {
432 if(i_inst.value < staged_num_ds_read_inst_a)
433 {
434 __builtin_amdgcn_sched_group_barrier(
435 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
436 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
437 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
438 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
439 }
440 else
441 {
442 __builtin_amdgcn_sched_group_barrier(
443 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
444 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
445 }
446 }
447 });
448#endif
449 __builtin_amdgcn_sched_barrier(0);
450 }
451 else
452 {
453 // A local Read
455 ignore = i_inst;
456 __builtin_amdgcn_sched_group_barrier(
457 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
458 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
459 });
460
461 __builtin_amdgcn_sched_barrier(0);
462 }
463 }
464
// Emits the scheduling hints for the second epilogue step: groups of
// staged_num_mfma_per_ds_read_a MFMAs, each followed by one DS read, then a
// closing __builtin_amdgcn_sched_barrier(0). Budgets use truncating division.
//
// NOTE(review): listing line 477 (the opener of the loop whose body uses
// `i_inst`) is absent from this extract; its trip count cannot be confirmed.
465 __device__ static constexpr auto EpilogueScheduler_2()
466 {
467 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
468
469 constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
470
471 constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
472 constexpr auto staged_num_mfma = num_mfma / MRepeat;
473
474 constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
475
476 // A local Read
478 ignore = i_inst;
479 __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
480 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
481 });
482
483 __builtin_amdgcn_sched_barrier(0);
484 }
485
// Runs the pipelined block GEMM: a prologue (global prefetch, local prefill,
// local prefetch and the first B dequant, as labelled inline), an optional hot
// loop, and a tail selected at compile time.
//
// HasMainLoop  - when true, the do/while hot loop below is executed.
// TailNum      - TailNumber::Even takes the two-step tail; any other value
//                takes the single-step tail in the final `else` branch.
// a_block_buf  - accessed through .At(I0) / .At(I1), i.e. as two A buffers.
// b_block_buf  - unused; it is assigned to `ignore`.
// c_thread_buf - cleared before the main body, then accumulated into by
//                xdlops_gemm.Run.
// num_loop     - iteration count; the hot loop advances `i` by 2 per trip and
//                continues while i < num_loop - 2.
//
// NOTE(review): many listing lines are absent from this extract (for example
// 516, 518, 520 -- presumably the declarations of a_thread_buf, b_thread_buf
// and b_thread_dequant_buf -- as well as the heads of several copy/dequant
// calls and the a_thread_vec / b_thread_vec declarations). Comments below
// describe only what the visible lines show.
486 template <bool HasMainLoop,
487 TailNumber TailNum,
488 typename AGridDesc,
489 typename ABlockDesc,
490 typename ABlockTransfer,
491 typename AGridBuffer,
492 typename ABlockBuffer,
493 typename ABlockTransferStep,
494 typename BGridDesc,
495 typename BBlockTransfer,
496 typename BGridBuffer,
497 typename BBlockBuffer,
498 typename BBlockTransferStep,
499 typename CThreadBuffer>
500 __device__ void Run(const AGridDesc& a_grid_desc,
501 const ABlockDesc& a_block_desc,
502 ABlockTransfer& a_blockwise_copy,
503 const AGridBuffer& a_grid_buf,
504 ABlockBuffer& a_block_buf,
505 const ABlockTransferStep& a_block_copy_step,
506 const BGridDesc& b_grid_desc,
507 BBlockTransfer& b_blockwise_copy,
508 const BGridBuffer& b_grid_buf,
509 BBlockBuffer& b_block_buf,
510 const BBlockTransferStep& b_block_copy_step,
511 CThreadBuffer& c_thread_buf,
512 index_t num_loop) const
513 {
514 ignore = b_block_buf;
515 __builtin_amdgcn_sched_barrier(0);
517 a_thread_desc_.GetElementSpaceSize());
519 b_thread_desc_.GetElementSpaceSize());
521 b_thread_desc_.GetElementSpaceSize());
522
// Two-slot arrays: one slot holds the B data being consumed by the MFMAs while
// the other receives the next load / dequant result.
523 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
524 StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
525 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
526
527 // Global prefetch A1 B1
528 b_blockwise_copy.Run(b_grid_desc,
529 b_grid_buf,
531 b_block_origin_idx,
532 b_thread_bufs(I0));
533 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
534
535 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
536 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
537 __builtin_amdgcn_sched_barrier(0);
538
539 // // Local prefill A1
540 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
541
542 // // Global prefetch A2
543 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
544 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
545
546 // Local prefetch A1
548 static_for<0, KRepeat, 1>{}([&](auto k0) {
550 make_tuple(I0, I0, I0, k0, I0, I0),
551 a_block_buf.At(I0),
553 make_tuple(I0, I0, I0, k0, I0, I0),
554 a_thread_buf);
555 });
556 // B VGPR->VGPR dequant
558 b_block_origin_idx,
559 b_thread_bufs(I0),
561 make_tuple(I0, I0, I0, I0),
562 b_thread_dequant_bufs(I0));
563
564 // Initialize C
565 c_thread_buf.Clear();
566
567 __builtin_amdgcn_sched_barrier(0);
568
569 // main body
570 if constexpr(HasMainLoop)
571 {
572 index_t i = 0;
573 do
574 {
// One pipeline step. mfma_reg_buf selects the slot the MFMAs read from;
// local_read_buf selects the slot that is refilled for the next step.
575 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
576 static_for<0, MRepeat, 1>{}([&](auto m0) {
// Memory traffic is spread over the first three m0 iterations:
// m0 == 0 issues the B global load, m0 == 1 the A LDS write,
// m0 == 2 the A global read.
577 if constexpr(m0.value == 0)
578 {
579 b_blockwise_copy.Run(b_grid_desc,
580 b_grid_buf,
582 b_block_origin_idx,
583 b_thread_bufs(local_read_buf));
584 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
585 }
586 else if constexpr(m0.value == 1)
587 {
588 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
589 }
590 else if constexpr(m0.value == 2)
591 {
592 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
593 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
594 }
595
596 static_for<0, KRepeat, 1>{}([&](auto k0) {
597 static_for<0, NRepeat, 1>{}([&](auto n0) {
600
// Gather KPack elements of A and B into the MFMA operand vectors. The A
// thread buffer is indexed modulo 2 in its first dimension, offset by
// HotloopLocalBufSwitch * mfma_reg_buf.
601 static_for<0, KPack, 1>{}([&](auto ik) {
602 a_thread_vec.template AsType<ComputeDataType>()(ik) =
603 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
604 make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
605 2,
606 I0,
607 I0,
608 k0,
609 I0,
610 ik))>{}];
611 b_thread_vec.template AsType<ComputeDataType>()(ik) =
612 b_thread_dequant_bufs[mfma_reg_buf]
613 [Number<b_thread_desc_.CalculateOffset(
614 make_tuple(n0, I0, k0, ik))>{}];
615 });
616
617 using mfma_input_type =
618 typename vector_type<ComputeDataType,
619 xdlops_gemm.K1PerXdlops>::type;
620
621 constexpr index_t c_offset =
622 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
623
624 xdlops_gemm.Run(
625 a_thread_vec.template AsType<mfma_input_type>(),
626 b_thread_vec.template AsType<mfma_input_type>(),
627 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
628 });
629 });
630
// Prefetch the next A row into the thread buffer. On the last m0 the source
// switches to the freshly written buffer (local_read_buf) and the newly loaded
// B slot is dequantized; otherwise the current buffer (mfma_reg_buf) is used.
631 if constexpr(m0.value == MRepeat - 1)
632 {
634
635 static_for<0, KRepeat, 1>{}([&](auto k0) {
636 a_thread_copy_.Run(
638 make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
639 a_block_buf.At(local_read_buf),
642 Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
643 2>{},
644 I0,
645 I0,
646 k0,
647 I0,
648 I0),
649 a_thread_buf);
650 });
651 // B VGPR->VGPR dequant
653 b_block_origin_idx,
654 b_thread_bufs(local_read_buf),
656 make_tuple(I0, I0, I0, I0),
657 b_thread_dequant_bufs(local_read_buf));
658 }
659 else
660 {
661 static_for<0, KRepeat, 1>{}([&](auto k0) {
662 a_thread_copy_.Run(
664 make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
665 a_block_buf.At(mfma_reg_buf),
668 Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
669 2>{},
670 I0,
671 I0,
672 k0,
673 I0,
674 I0),
675 a_thread_buf);
676 });
677 // B VGPR->VGPR dequant
679 b_block_origin_idx,
680 b_thread_bufs(mfma_reg_buf),
682 make_tuple(I0, I0, I0, I0),
683 b_thread_dequant_bufs(mfma_reg_buf));
684 }
685
687 });
688 };
689
// Two steps per trip with the buffer roles swapped, hence i += 2.
690 LoopFunc(I0, I1);
691 LoopFunc(I1, I0);
692
693 i += 2;
694 } while(i < (num_loop - 2));
695 }
696 // tail
697 if constexpr(TailNum == TailNumber::Even)
698 {
// Even tail, first step: compute from slot I0 while the last B load and the
// last A LDS write fill slot/buffer I1. No further A global read is issued.
699 static_for<0, MRepeat, 1>{}([&](auto m0) {
700 if constexpr(m0.value == 0)
701 {
702 b_blockwise_copy.Run(b_grid_desc,
703 b_grid_buf,
705 b_block_origin_idx,
706 b_thread_bufs(I1));
707 }
708 else if constexpr(m0.value == MRepeat - 1)
709 {
710 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
711 }
712
713 static_for<0, KRepeat, 1>{}([&](auto k0) {
714 static_for<0, NRepeat, 1>{}([&](auto n0) {
717
718 static_for<0, KPack, 1>{}([&](auto ik) {
719 a_thread_vec.template AsType<ComputeDataType>()(ik) =
720 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
721 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
722 b_thread_vec.template AsType<ComputeDataType>()(ik) =
723 b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
724 make_tuple(n0, I0, k0, ik))>{}];
725 });
726
727 using mfma_input_type =
728 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
729
730 constexpr index_t c_offset =
731 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
732
733 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
734 b_thread_vec.template AsType<mfma_input_type>(),
735 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
736 });
737 });
738
739 if constexpr(m0.value == MRepeat - 1)
740 {
742
743 static_for<0, KRepeat, 1>{}([&](auto k0) {
744 a_thread_copy_.Run(
746 make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
747 a_block_buf.At(I1),
749 make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
750 a_thread_buf);
751 });
752 // B VGPR->VGPR dequant
754 b_block_origin_idx,
755 b_thread_bufs(I1),
757 make_tuple(I0, I0, I0, I0),
758 b_thread_dequant_bufs(I1));
759 }
760 else
761 {
762 static_for<0, KRepeat, 1>{}([&](auto k0) {
763 a_thread_copy_.Run(
765 make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
766 a_block_buf.At(I0),
768 make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
769 a_thread_buf);
770 });
771 // B VGPR->VGPR dequant
773 b_block_origin_idx,
774 b_thread_bufs(I0),
776 make_tuple(I0, I0, I0, I0),
777 b_thread_dequant_bufs(I0));
778 }
779
781 });
782
// Even tail, second step: compute from slot I1 / A buffer I1; only local reads
// remain, and none is issued after the last m0.
783 static_for<0, MRepeat, 1>{}([&](auto m0) {
784 static_for<0, KRepeat, 1>{}([&](auto k0) {
785 static_for<0, NRepeat, 1>{}([&](auto n0) {
788
789 static_for<0, KPack, 1>{}([&](auto ik) {
790 a_thread_vec.template AsType<ComputeDataType>()(ik) =
791 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
792 (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
793 b_thread_vec.template AsType<ComputeDataType>()(ik) =
794 b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
795 make_tuple(n0, I0, k0, ik))>{}];
796 });
797
798 using mfma_input_type =
799 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
800
801 constexpr index_t c_offset =
802 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
803
804 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
805 b_thread_vec.template AsType<mfma_input_type>(),
806 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
807 });
808 });
809
810 if constexpr(m0.value != (MRepeat - 1))
811 {
812 static_for<0, KRepeat, 1>{}([&](auto k0) {
813 a_thread_copy_.Run(
815 make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
816 a_block_buf.At(I1),
819 Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
820 a_thread_buf);
821 });
822 // B VGPR->VGPR dequant
824 b_block_origin_idx,
825 b_thread_bufs(I1),
827 make_tuple(I0, I0, I0, I0),
828 b_thread_dequant_bufs(I1));
829
831 }
832 });
833 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
834 // latency
835 // __builtin_amdgcn_sched_barrier(0);
836 }
837 else
838 {
// Non-even tail: a single step computing from slot I0 / A buffer I0, with no
// global loads or LDS writes.
839 static_for<0, MRepeat, 1>{}([&](auto m0) {
840 static_for<0, KRepeat, 1>{}([&](auto k0) {
841 static_for<0, NRepeat, 1>{}([&](auto n0) {
844
845 static_for<0, KPack, 1>{}([&](auto ik) {
846 a_thread_vec.template AsType<ComputeDataType>()(ik) =
847 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
848 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
849 b_thread_vec.template AsType<ComputeDataType>()(ik) =
850 b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
851 make_tuple(n0, I0, k0, ik))>{}];
852 });
853
854 using mfma_input_type =
855 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
856
857 constexpr index_t c_offset =
858 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
859
860 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
861 b_thread_vec.template AsType<mfma_input_type>(),
862 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
863 });
864 });
865
866 if constexpr(m0.value != (MRepeat - 1))
867 {
868 static_for<0, KRepeat, 1>{}([&](auto k0) {
870 make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
871 a_block_buf.At(I0),
873 make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
874 a_thread_buf);
875 });
876 // B VGPR->VGPR dequant
878 b_block_origin_idx,
879 b_thread_bufs(I0),
881 make_tuple(I0, I0, I0, I0),
882 b_thread_dequant_bufs(I0));
883
885 }
886 });
887 }
888 }
889
890 protected:
891 // MRepeat MWave MLane KRepeat KLane KPack
892 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
893 // Reduce the vgpr usage here.
896
898 ComputeDataType,
900 decltype(a_thread_desc_),
903 5,
904 A_K1,
905 A_K1>;
906
908
911
912 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
913
915
917
925 Sequence<1, 2, 0, 3>,
926 3,
927 KPack>;
928
931};
932
933} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops_base.hpp:44
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr auto I2
Definition blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:897
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:500
ThreadwiseTensorSliceTransfer_StaticToStatic< BDataType, ComputeDataType, decltype(b_block_desc_n0_n1_k0_k1), decltype(b_block_desc_n0_n1_k0_k1), tensor_operation::element_wise::PassThrough, Sequence< Number< NRepeat >{}, I1, Number< KRepeat >{}, Number< KPack >{}>, Sequence< 1, 2, 0, 3 >, 3, KPack > BThreadDequantCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:918
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp:37
Definition utility/sequence.hpp:43
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition xdlops_gemm.hpp:1821
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition dtype_vector.hpp:10