blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
69 BlockSize,
70 ADataType,
71 BDataType,
72 ComputeDataType,
73 AccDataType,
74 ATileDesc,
75 BTileDesc,
76 AMmaTileDesc,
77 BMmaTileDesc,
78 ABlockTransferSrcScalarPerVector,
79 BBlockTransferSrcScalarPerVector,
80 MPerBlock,
81 NPerBlock,
82 KPerBlock,
83 MScaleBlock,
84 NScaleBlock,
85 KScaleBlock,
86 MPerXDL,
87 NPerXDL,
88 MRepeat,
89 NRepeat,
90 KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::I2;
137 using Base::KGroup;
138 using Base::KRepeat;
139 using Base::xdlops_gemm;
140 using typename Base::HotLoopInstList;
141
154 using Base::MWaves;
155 using Base::WaveSize;
156
// Pipeline stage counts. The header comment of this file lists two global
// prefetch stages and one local prefill stage for this pipeline;
// BlockHasHotloop() compares the loop count against PrefetchStages.
157 static constexpr index_t PrefetchStages = 2;
158 static constexpr index_t PrefillStages = 1;
159 static constexpr index_t GlobalBufferNum = 1;
// 0 when MRepeat is even, 1 when MRepeat is odd. Run() adds this value
// (multiplied by the MFMA buffer index) to the A thread-buffer slot index
// before reducing it modulo 2.
160 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
161
// Builds the A-side MMA tile descriptor from a descriptor type whose first
// three dimensions are read as M0, M1 and M2 below. Only the type is used;
// the function parameter itself is unnamed.
// NOTE(review): the return expression is not fully visible in this listing
// (several lines were dropped), so the resulting layout cannot be confirmed
// from here.
162 template <typename TileDesc_M0_M1_M2_K>
163 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
164 {
// Lengths of dimensions 0..2 of the incoming descriptor type.
165 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
166 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
167 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
// K-side factors: K0 = KRepeat * KGroup, K1 = WaveSize / NPerXDL,
// K2 = KPack / KGroup. Presumably the K dimension is split into (K0, K1, K2)
// by the dropped lines - TODO confirm against the full source.
168 constexpr index_t K2 = KPack / KGroup;
169 constexpr index_t K1 = WaveSize / NPerXDL;
170 constexpr index_t K0 = KRepeat * KGroup;
171
173 TileDesc_M0_M1_M2_K{},
181 }
182
183 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
185
186 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
191 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
194 }
195
// Emits the instruction-scheduling hints for the hot loop.
// Every quantity below is a compile-time constant derived from
// HotLoopInstList and MRepeat; the visible body issues only
// __builtin_amdgcn_sched_group_barrier calls. Mask meanings, as annotated on
// the calls themselves: 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read,
// 0x200 = DS write.
// NOTE(review): this listing has dropped a few lines (the tail of the
// num_ds_read_inst_a initializer and the loop headers that introduce `i` for
// the three sections), so the exact stage ranges cannot be confirmed here.
196 __device__ static constexpr auto HotLoopScheduler()
197 {
198 // A/B split schedule
199 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
200 constexpr auto num_ds_read_inst_a =
201 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
204
205 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
206
207 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
208 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
209
// The A section below pairs DS writes with buffer loads using the same
// per-stage conditions, so the two counts are required to match.
210 static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
211
212 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
213 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
214
// Number of A DS reads grouped behind one MFMA:
// ceil((mfma_cycle - 4) / (2 * issue cycle)), with an issue cycle of 8 for
// 16-byte reads and 4 otherwise.
215 constexpr auto ds_read_a_issue_cycle =
216 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
217 constexpr auto ds_read_a_mfma_rate =
218 math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
219
220 // constexpr auto num_dsread_a_mfma =
221 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
222
// One scheduling stage per MRepeat step.
223 constexpr auto num_total_stages = MRepeat;
224
225 // Group num_mfma_perstage num_ds_read_a_perstage
226 // since we want to reuse a local register buffer
227 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
228 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
229
230 constexpr auto num_ds_read_a_mfma_perstage =
231 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
232
233 constexpr auto num_ds_read_a_prefetch_stages = 2;
234
// The combined A+B buffer loads are spread over (num_total_stages - 2) stages:
// "more" is the rounded-up per-stage count, "less" the rounded-down one.
// NOTE(review): (num_total_stages - 2) is a divisor here and is zero when
// MRepeat == 2; buffer_load_perstage_less is also used as a divisor further
// down and is zero when there are fewer loads than stages. Presumably only
// configurations avoiding both cases are instantiated - confirm.
235 constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
236 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
237 constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
238 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
239
// Remainder of the division above: how many stages take the "more" count.
240 constexpr auto buffer_load_stages_more =
241 (num_buffer_load_inst_a + num_buffer_load_inst_b) -
242 math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
243 (num_total_stages - 2)) *
244 ((num_total_stages - 2));
245
// Number of stages given to B loads; the A loads take the stages that remain
// after the two DS-read prefetch stages are set aside.
246 constexpr auto buffer_load_b_stages =
247 buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
248 ? num_buffer_load_inst_b / buffer_load_perstage_more
249 : (buffer_load_stages_more +
250 (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
251 buffer_load_perstage_less);
252
253 constexpr auto buffer_load_a_stages =
254 num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
255
// Issue points are MFMA indices (taken modulo the interval) at which a load
// or write hint is emitted; an interval that would be 0 is clamped to 1.
256 constexpr auto buffer_load_issue_point_b = 0;
257 constexpr auto buffer_load_issue_point_interval_more =
258 num_mfma_perstage / buffer_load_perstage_more
259 ? num_mfma_perstage / buffer_load_perstage_more
260 : 1;
261 constexpr auto buffer_load_issue_point_interval_less =
262 num_mfma_perstage / buffer_load_perstage_less
263 ? num_mfma_perstage / buffer_load_perstage_less
264 : 1;
265 constexpr auto ds_write_issue_point = 0;
266 constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
267
268 // B global read
270 // Scale load, 1B
271 if constexpr(i.value == 0)
272 {
273 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
274 }
275 // Scale load, 1A
276 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
277
278 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
279 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
280
281 if constexpr(((i < buffer_load_stages_more) &&
282 (imfma % buffer_load_issue_point_interval_more ==
283 buffer_load_issue_point_b)) ||
284 ((i >= buffer_load_stages_more) &&
285 (imfma % buffer_load_issue_point_interval_less ==
286 buffer_load_issue_point_b)))
287 {
288 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
289 }
290
// The last num_ds_read_a_mfma_perstage MFMAs of the stage are each followed
// by a group of ds_read_a_mfma_rate DS reads.
291 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
292 {
293 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
294 }
295 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
296 });
297 // __builtin_amdgcn_sched_barrier(0);
298 });
299
300 // A global read + A local write
302 // Scale load, 1A
303 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
304 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
305 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// The stage index is offset by buffer_load_b_stages so the more/less split
// continues from where the B section stopped.
306 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
307 (imfma % buffer_load_issue_point_interval_more ==
308 ds_write_issue_point)) ||
309 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
310 (imfma % buffer_load_issue_point_interval_less ==
311 ds_write_issue_point)))
312 {
313 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
314 }
315 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
316 (imfma % buffer_load_issue_point_interval_more ==
317 buffer_load_issue_point_a)) ||
318 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
319 (imfma % buffer_load_issue_point_interval_less ==
320 buffer_load_issue_point_a)))
321 {
322 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
323 }
324 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
325 {
326 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
327 }
328 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
329 });
330 // __builtin_amdgcn_sched_barrier(0);
331 });
332
333 // lds synchronization, prefetch next loop local A
// The stage index is unused in this section (assigned to `ignore`).
335 ignore = i;
336 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
337 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
338 // Scale load, 1A
339 if constexpr(imfma == 0)
340 {
341 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
342 }
343
344 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
345 {
346 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
347 }
348 // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
349 });
350 // __builtin_amdgcn_sched_barrier(0);
351 });
352 }
353
// Executes the block GEMM pipeline for `num_loop` iterations and accumulates
// the scaled products into c_thread_buf.
//
// Compile-time switches:
//   HasMainLoop       - when true, the do/while hot loop below is compiled in.
//   NumKBlockPerScale - selects which entry of a_scale_thread_copy_step is used
//                       to advance the A-scale source window (entry 1 when the
//                       value is 1, entry 0 otherwise).
//   TailNum           - TailNumber::Even takes the two-pass tail; any other
//                       value takes the single-pass tail in the else branch.
//
// Visible structure: prologue (global prefetch of B, A and both scales, local
// prefill of A via RunWrite, first local read of A), the optional hot loop
// that alternates buffer indices through LoopFunc(I0, I1) / LoopFunc(I1, I0),
// and finally the tail. b_block_desc and b_block_buf are accepted but unused
// (both are assigned to `ignore`).
//
// NOTE(review): this listing has dropped several source lines (for example
// the declarations of a_thread_buf / b_thread_buf and the scale buffers, and
// parts of the a_thread_copy_.Run argument lists), so the comments here only
// describe what is visible.
354 template <bool HasMainLoop,
355 int NumKBlockPerScale,
356 TailNumber TailNum,
357 typename AGridDesc,
358 typename ABlockDesc,
359 typename ABlockTransfer,
360 typename AGridBuffer,
361 typename ABlockBuffer,
362 typename ABlockTransferStep,
363 typename BGridDesc,
364 typename BBlockDesc,
365 typename BBlockTransfer,
366 typename BGridBuffer,
367 typename BBlockBuffer,
368 typename BBlockTransferStep,
369 typename CScaleThreadDesc,
370 typename CThreadBuffer,
371 typename AScaleGridBuffer,
372 typename AScaleGridDesc,
373 typename AScaleThreadDesc,
374 typename AScaleThreadTransfer,
375 typename AScaleThreadTransferStep,
376 typename BScaleGridBuffer,
377 typename BScaleGridDesc,
378 typename BScaleThreadDesc,
379 typename BScaleThreadTransfer,
380 typename BScaleThreadTransferStep>
381 __device__ void Run(
382 // ABlockCopy
383 const AGridDesc& a_grid_desc,
384 const ABlockDesc& a_block_desc,
385 ABlockTransfer& a_blockwise_copy,
386 const AGridBuffer& a_grid_buf,
387 ABlockBuffer& a_block_buf,
388 const ABlockTransferStep& a_block_copy_step,
389 // BBlockCopy
390 const BGridDesc& b_grid_desc,
391 const BBlockDesc& b_block_desc,
392 BBlockTransfer& b_blockwise_copy,
393 const BGridBuffer& b_grid_buf,
394 BBlockBuffer& b_block_buf,
395 const BBlockTransferStep& b_block_copy_step,
396 // CThread
397 const CScaleThreadDesc& c_scale_thread_desc,
398 CThreadBuffer& c_thread_buf,
399 // AScaleThreadCopy
400 const AScaleGridDesc& a_scale_grid_desc,
401 const AScaleThreadDesc& a_scale_thread_desc,
402 AScaleThreadTransfer& a_scale_thread_copy,
403 const AScaleGridBuffer& a_scale_grid_buf,
404 const AScaleThreadTransferStep& a_scale_thread_copy_step,
405 // BScaleThreadCopy
406 const BScaleGridDesc& b_scale_grid_desc,
407 const BScaleThreadDesc& b_scale_thread_desc,
408 BScaleThreadTransfer& b_scale_thread_copy,
409 const BScaleGridBuffer& b_scale_grid_buf,
410 const BScaleThreadTransferStep& b_scale_thread_copy_step,
411 // num_loop
412 index_t num_loop) const
413 {
// B never goes through a block buffer in this pipeline: it is copied straight
// into b_thread_bufs by b_blockwise_copy.Run below.
414 ignore = b_block_desc;
415 ignore = b_block_buf;
416 __builtin_amdgcn_sched_barrier(0);
// Dimensions 0 and 2 of the C-scale thread descriptor must have length 1.
417 static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
418 "Pipeline v3 only support scaleblocksliceK=1");
419 static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
420 "Pipeline v3 only support scaleblocksliceN=1");
421 // assume kperblock = scaleblockk
423 a_thread_desc_.GetElementSpaceSize());
425 b_thread_desc_.GetElementSpaceSize());
426
// Two-entry arrays: one entry is consumed by the MFMAs while the other is
// being filled (see the mfma_reg_buf / local_read_buf roles in LoopFunc).
427 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
428 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
429
431 a_scale_thread_desc.GetElementSpaceSize());
433 b_scale_thread_desc.GetElementSpaceSize());
435 c_scale_thread_desc.GetElementSpaceSize());
436
437 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
438 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
439 // StaticallyIndexedArray<decltype(c_scale_thread_buf), Number<2>{}> c_scale_thread_bufs;
440
441 // Global prefetch A1 B1, AScale1 BScale1
442 b_blockwise_copy.Run(b_grid_desc,
443 b_grid_buf,
445 b_block_origin_idx,
446 b_thread_bufs(I0));
447 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
448
449 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
450 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
451 __builtin_amdgcn_sched_barrier(0);
452
453 a_scale_thread_copy.Run(a_scale_grid_desc,
454 a_scale_grid_buf,
455 a_scale_thread_desc,
456 make_tuple(I0, I0),
457 a_scale_thread_bufs(I0));
458
459 if constexpr(NumKBlockPerScale == 1)
460 {
461 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
462 a_scale_thread_copy_step.At(Number<1>{}));
463 }
464 else
465 {
466 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
467 a_scale_thread_copy_step.At(Number<0>{}));
468 }
469
470 b_scale_thread_copy.Run(b_scale_grid_desc,
471 b_scale_grid_buf,
472 b_scale_thread_desc,
473 make_tuple(I0, I0),
474 b_scale_thread_bufs(I0));
475
476 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
477
// Combined scale for the first compute pass: a_scale[m0] * b_scale[0] for
// every m0 in [0, MRepeat).
478 static_for<0, MRepeat, 1>{}([&](auto m0) {
479 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
480 });
481
482 // Local prefill A1
483 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
484
485 // Global prefetch A2, AScale2 BScale2
486 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
487 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
488
// The second scale prefetch reuses slot I0: the values previously held there
// were already folded into c_scale_thread_buf above.
489 a_scale_thread_copy.Run(a_scale_grid_desc,
490 a_scale_grid_buf,
491 a_scale_thread_desc,
492 make_tuple(I0, I0),
493 a_scale_thread_bufs(I0));
494
495 if constexpr(NumKBlockPerScale == 1)
496 {
497 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
498 a_scale_thread_copy_step.At(Number<1>{}));
499 }
500 else
501 {
502 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
503 a_scale_thread_copy_step.At(Number<0>{}));
504 }
505
506 b_scale_thread_copy.Run(b_scale_grid_desc,
507 b_scale_grid_buf,
508 b_scale_thread_desc,
509 make_tuple(I0, I0),
510 b_scale_thread_bufs(I0));
511
512 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
513
514 // Initialize C
515 c_thread_buf.Clear();
516
517 // Double register buffer for non-scaled gemm computation
518 // 1. Reduce register pressure
519 // 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
521 AccDataType,
522 1,
523 xdlops_gemm.GetRegSizePerXdlops(),
524 true>
525 c_thread_buf_per_scale;
526
527 // Local prefetch A1
// Only the first two m0 slices are read here; the thread-side A index is
// always reduced modulo 2 in the compute code below.
529 static_for<0, 2, 1>{}([&](auto m0) {
530 static_for<0, KRepeat, 1>{}([&](auto k0) {
531 static_for<0, KGroup, 1>{}([&](auto kg0) {
534 a_block_buf.At(I0),
536 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
537 a_thread_buf);
538 });
539 });
540 });
541
542#if 0
543 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
544 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
545 .template AsType<AccDataType>()(Number<t>{}) = 0;
546 });
547
548 // Fill first mfma buffer
549 static_for<0, KRepeat, 1>{}([&](auto k0) {
552
553 static_for<0, KPack, 1>{}([&](auto ik) {
554 a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
555 [Number<a_thread_desc_.CalculateOffset(make_tuple(I0, I0, I0, k0, I0, ik))>{}];
556 b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
557 [I0][Number<b_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
558 });
559
560 using mfma_input_type =
561 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
562
563 xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
564 b_thread_vec.template AsType<mfma_input_type>(),
565 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
566 });
567#endif
568 __builtin_amdgcn_sched_barrier(0);
569
570 // main body
571 if constexpr(HasMainLoop)
572 {
573 index_t i = 0;
574 do
575 {
// One pipeline step. mfma_reg_buf is the index of the b_thread_bufs entry
// read by this step's MFMAs (and of the scale entries folded into
// c_scale_thread_buf at the end of the step); local_read_buf is the index of
// the entries filled by this step's loads.
576 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
577 b_blockwise_copy.Run(b_grid_desc,
578 b_grid_buf,
580 b_block_origin_idx,
581 b_thread_bufs(local_read_buf));
582 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
583
584 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
585 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
586 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
587
588 a_scale_thread_copy.Run(a_scale_grid_desc,
589 a_scale_grid_buf,
590 a_scale_thread_desc,
591 make_tuple(I0, I0),
592 a_scale_thread_bufs(local_read_buf));
593
594 if constexpr(NumKBlockPerScale == 1)
595 {
596 a_scale_thread_copy.MoveSrcSliceWindow(
597 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
598 }
599 else
600 {
601 a_scale_thread_copy.MoveSrcSliceWindow(
602 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
603 }
604 b_scale_thread_copy.Run(b_scale_grid_desc,
605 b_scale_grid_buf,
606 b_scale_thread_desc,
607 make_tuple(I0, I0),
608 b_scale_thread_bufs(local_read_buf));
609
610 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
611 b_scale_thread_copy_step);
612
613 static_for<0, MRepeat, 1>{}([&](auto m0) {
// Duplicate the scalar scale into both lanes of a 2-wide vector so it can
// feed the packed FMA below.
614 vector_type<AccDataType, 2> c_scale_thread_vec;
615 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
616 c_scale_thread_buf[m0];
617 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
618 c_scale_thread_buf[m0];
619
620 static_for<0, NRepeat, 1>{}([&](auto n0) {
// Zero the per-scale accumulator before each (m0, n0) tile.
621 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
622 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
623 .template AsType<AccDataType>()(Number<t>{}) = 0;
624 });
625 static_for<0, KRepeat, 1>{}([&](auto k0) {
628
629 static_for<0, KPack, 1>{}([&](auto ik) {
630 a_thread_vec.template AsType<ComputeDataType>()(ik) =
631 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
632 make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
633 2,
634 I0,
635 I0,
636 k0,
637 I0,
638 ik))>{}];
639 b_thread_vec.template AsType<ComputeDataType>()(ik) =
640 b_thread_bufs[mfma_reg_buf]
641 [Number<b_thread_desc_.CalculateOffset(
642 make_tuple(n0, I0, k0, ik))>{}];
643 });
644
645 using mfma_input_type =
646 typename vector_type<ComputeDataType,
647 xdlops_gemm.K1PerXdlops>::type;
648
649 xdlops_gemm.template Run<>(
650 a_thread_vec.template AsType<mfma_input_type>(),
651 b_thread_vec.template AsType<mfma_input_type>(),
652 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
653 });
654
655 constexpr index_t c_offset =
656 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
657
// c_thread_buf = fma(per-scale accumulator, scale, c_thread_buf), processed
// two accumulator elements at a time.
658 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
659 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
660
661 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
662 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
663 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
664 .template AsType<pk_fma_type>()[t],
665 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
666 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
667 .template AsType<pk_fma_type>()[t]);
668 });
669 });
670
// Refill the A thread slot for m0 + 2 (wrapping modulo MRepeat). The last two
// m0 values wrap to the next step's tile and therefore read from
// a_block_buf.At(local_read_buf); all earlier m0 values read from
// a_block_buf.At(mfma_reg_buf).
671 if constexpr(m0.value == (MRepeat - 2))
672 {
674
675 static_for<0, KRepeat, 1>{}([&](auto k0) {
676 static_for<0, KGroup, 1>{}([&](auto kg0) {
677 a_thread_copy_.Run(
679 make_tuple(Number<(m0 + 2) % MRepeat>{},
680 I0,
681 I0,
683 I0,
684 I0),
685 a_block_buf.At(local_read_buf),
688 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
689 2>{},
690 I0,
691 I0,
692 k0,
693 I0,
695 a_thread_buf);
696 });
697 });
698 }
699 else if constexpr(m0.value == (MRepeat - 1))
700 {
701 static_for<0, KRepeat, 1>{}([&](auto k0) {
702 static_for<0, KGroup, 1>{}([&](auto kg0) {
703 a_thread_copy_.Run(
705 make_tuple(Number<(m0 + 2) % MRepeat>{},
706 I0,
707 I0,
709 I0,
710 I0),
711 a_block_buf.At(local_read_buf),
714 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
715 2>{},
716 I0,
717 I0,
718 k0,
719 I0,
721 a_thread_buf);
722 });
723 });
724 }
725 else
726 {
727 static_for<0, KRepeat, 1>{}([&](auto k0) {
728 static_for<0, KGroup, 1>{}([&](auto kg0) {
729 a_thread_copy_.Run(
731 make_tuple(Number<(m0 + 2) % MRepeat>{},
732 I0,
733 I0,
735 I0,
736 I0),
737 a_block_buf.At(mfma_reg_buf),
740 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
741 2>{},
742 I0,
743 I0,
744 k0,
745 I0,
747 a_thread_buf);
748 });
749 });
750 }
751 });
752
// Prepare the combined scale for the next step from the scale entries at
// index mfma_reg_buf.
753 static_for<0, MRepeat, 1>{}([&](auto m0) {
754 c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
755 b_scale_thread_bufs[mfma_reg_buf][I0];
756 });
757
759 __builtin_amdgcn_sched_barrier(0);
760 };
761
// Two steps per do/while iteration with the buffer roles swapped, so `i`
// advances by 2 and the loop continues while i < num_loop - 2.
762 LoopFunc(I0, I1);
763 LoopFunc(I1, I0);
764
765 i += 2;
766 } while(i < (num_loop - 2));
767 }
768
769 // tail
770 if constexpr(TailNum == TailNumber::Even)
771 {
// Even tail: fetch one more B tile into entry I1 and write the pending A tile
// to a_block_buf.At(I1), then run two compute passes (entry I0, then I1).
772 b_blockwise_copy.Run(b_grid_desc,
773 b_grid_buf,
775 b_block_origin_idx,
776 b_thread_bufs(I1));
777 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
778
779 static_for<0, MRepeat, 1>{}([&](auto m0) {
780 vector_type<AccDataType, 2> c_scale_thread_vec;
781 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
782 c_scale_thread_buf[m0];
783 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
784 c_scale_thread_buf[m0];
785
786 static_for<0, NRepeat, 1>{}([&](auto n0) {
787 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
788 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
789 .template AsType<AccDataType>()(Number<t>{}) = 0;
790 });
791 static_for<0, KRepeat, 1>{}([&](auto k0) {
794
795 static_for<0, KPack, 1>{}([&](auto ik) {
796 a_thread_vec.template AsType<ComputeDataType>()(ik) =
797 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
798 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
799 b_thread_vec.template AsType<ComputeDataType>()(ik) =
800 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
801 make_tuple(n0, I0, k0, ik))>{}];
802 });
803
804 using mfma_input_type =
805 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
806
807 xdlops_gemm.template Run<>(
808 a_thread_vec.template AsType<mfma_input_type>(),
809 b_thread_vec.template AsType<mfma_input_type>(),
810 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
811 });
812
813 constexpr index_t c_offset =
814 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
815
816 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
817 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
818
819 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
820 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
821 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
822 .template AsType<pk_fma_type>()[t],
823 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
824 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
825 .template AsType<pk_fma_type>()[t]);
826 });
827 });
828
// Same refill pattern as in the hot loop, with fixed block-buffer indices:
// the last two m0 values read from a_block_buf.At(I1), the rest from At(I0).
829 if constexpr(m0.value == (MRepeat - 2))
830 {
832
833 static_for<0, KRepeat, 1>{}([&](auto k0) {
834 static_for<0, KGroup, 1>{}([&](auto kg0) {
835 a_thread_copy_.Run(
837 make_tuple(Number<(m0 + 2) % MRepeat>{},
838 I0,
839 I0,
841 I0,
842 I0),
843 a_block_buf.At(I1),
846 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
847 a_thread_buf);
848 });
849 });
850 }
851 else if constexpr(m0.value == (MRepeat - 1))
852 {
853 static_for<0, KRepeat, 1>{}([&](auto k0) {
854 static_for<0, KGroup, 1>{}([&](auto kg0) {
855 a_thread_copy_.Run(
857 make_tuple(Number<(m0 + 2) % MRepeat>{},
858 I0,
859 I0,
861 I0,
862 I0),
863 a_block_buf.At(I1),
866 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
867 a_thread_buf);
868 });
869 });
870 }
871 else
872 {
873 static_for<0, KRepeat, 1>{}([&](auto k0) {
874 static_for<0, KGroup, 1>{}([&](auto kg0) {
875 a_thread_copy_.Run(
877 make_tuple(Number<(m0 + 2) % MRepeat>{},
878 I0,
879 I0,
881 I0,
882 I0),
883 a_block_buf.At(I0),
886 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
887 a_thread_buf);
888 });
889 });
890 }
891 });
892
894
// Refresh the combined scale from scale entry I0 for the second pass.
895 static_for<0, MRepeat, 1>{}([&](auto m0) {
896 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
897 });
898
// Second pass: consumes b_thread_bufs[I1]; the A thread-slot index is offset
// by HotloopLocalBufSwitch.
899 static_for<0, MRepeat, 1>{}([&](auto m0) {
900 vector_type<AccDataType, 2> c_scale_thread_vec;
901 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
902 c_scale_thread_buf[m0];
903 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
904 c_scale_thread_buf[m0];
905
906 static_for<0, NRepeat, 1>{}([&](auto n0) {
907 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
908 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
909 .template AsType<AccDataType>()(Number<t>{}) = 0;
910 });
911 static_for<0, KRepeat, 1>{}([&](auto k0) {
914
915 static_for<0, KPack, 1>{}([&](auto ik) {
916 a_thread_vec.template AsType<ComputeDataType>()(ik) =
917 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
918 (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
919 b_thread_vec.template AsType<ComputeDataType>()(ik) =
920 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
921 make_tuple(n0, I0, k0, ik))>{}];
922 });
923
924 using mfma_input_type =
925 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
926
927 xdlops_gemm.template Run<>(
928 a_thread_vec.template AsType<mfma_input_type>(),
929 b_thread_vec.template AsType<mfma_input_type>(),
930 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
931 });
932 constexpr index_t c_offset =
933 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
934
935 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
936 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
937
938 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
939 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
940 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
941 .template AsType<pk_fma_type>()[t],
942 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
943 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
944 .template AsType<pk_fma_type>()[t]);
945 });
946 });
947
// No wrap-around on the final pass: A is only refilled while m0 + 2 is still
// inside [0, MRepeat).
948 if constexpr(m0.value < (MRepeat - 2))
949 {
950 static_for<0, KRepeat, 1>{}([&](auto k0) {
951 static_for<0, KGroup, 1>{}([&](auto kg0) {
952 a_thread_copy_.Run(
956 a_block_buf.At(I1),
958 make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
959 I0,
960 I0,
961 k0,
962 I0,
964 a_thread_buf);
965 });
966 });
967 }
968 });
969 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
970 // latency
971 // // __builtin_amdgcn_sched_barrier(0);
972 }
973 else
974 {
// Non-even tail: a single compute pass over b_thread_bufs[I0] and
// a_block_buf.At(I0), using the combined scale already in c_scale_thread_buf.
975 static_for<0, MRepeat, 1>{}([&](auto m0) {
976 vector_type<AccDataType, 2> c_scale_thread_vec;
977 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
978 c_scale_thread_buf[m0];
979 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
980 c_scale_thread_buf[m0];
981
982 static_for<0, NRepeat, 1>{}([&](auto n0) {
983 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
984 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
985 .template AsType<AccDataType>()(Number<t>{}) = 0;
986 });
987 static_for<0, KRepeat, 1>{}([&](auto k0) {
990
991 static_for<0, KPack, 1>{}([&](auto ik) {
992 a_thread_vec.template AsType<ComputeDataType>()(ik) =
993 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
994 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
995 b_thread_vec.template AsType<ComputeDataType>()(ik) =
996 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
997 make_tuple(n0, I0, k0, ik))>{}];
998 });
999
1000 using mfma_input_type =
1001 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
1002
1003 xdlops_gemm.template Run<>(
1004 a_thread_vec.template AsType<mfma_input_type>(),
1005 b_thread_vec.template AsType<mfma_input_type>(),
1006 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
1007 });
1008 constexpr index_t c_offset =
1009 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1010
1011 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
1012 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
1013
1014 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1015 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1016 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
1017 .template AsType<pk_fma_type>()[t],
1018 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
1019 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1020 .template AsType<pk_fma_type>()[t]);
1021 });
1022 });
1023
1024 if constexpr(m0.value < (MRepeat - 2))
1025 {
1026 static_for<0, KRepeat, 1>{}([&](auto k0) {
1027 static_for<0, KGroup, 1>{}([&](auto kg0) {
1028 a_thread_copy_.Run(
1030 make_tuple(
1032 a_block_buf.At(I0),
1034 make_tuple(
1035 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
1036 a_thread_buf);
1037 });
1038 });
1039 }
1040 });
1041 }
1042 }
1043
1044 protected:
1045 // MRepeat MWave MLane KRepeat KLane KPack
1046 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
1047 // Reduce the vgpr usage here.
1050
1052 ComputeDataType,
1054 decltype(a_thread_desc_),
1055 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
1057 5,
1058 A_K1,
1059 A_K1>;
1060
1062
1065
1066 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
1067
1069};
1070
1071} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp:1051
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp:381
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp:112
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10