gemm_pipeline_ag_bg_cr_comp_v6.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v6.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v6.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v6.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3#pragma once
4#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A Tile Window: global memory
12// B Tile Window: global memory
13// C Distributed tensor: register
14template <typename Problem>
16{
17 static constexpr index_t PrefetchStages = 3;
18 static constexpr index_t PrefillStages = 1;
19 static constexpr index_t GlobalBufferNum = 2;
20 static constexpr index_t HotloopUnroll = 2;
21
22 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
23
24 CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
25 {
26 return num_loop > PrefetchStages;
27 }
28
30 {
31 if(num_loop % HotloopUnroll == 1)
32 {
33 return TailNumber::Odd;
34 }
35 else
36 {
37 return TailNumber::Even;
38 }
39 }
40
41 template <typename RunFunction>
42 CK_TILE_HOST_DEVICE static auto
43 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
44 {
45 // Handle all the valid cases.
46 if(has_hot_loop)
47 {
48 if(tail_number == TailNumber::Odd)
49 {
50 return run_func(bool_constant<true>{},
52 }
53 else if(tail_number == TailNumber::Even)
54 {
55 return run_func(bool_constant<true>{},
57 }
58 }
59 else
60 {
61 if(tail_number == TailNumber::Odd)
62 {
63 return run_func(bool_constant<false>{},
65 }
66 else if(tail_number == TailNumber::Even)
67 {
68 return run_func(bool_constant<false>{},
70 }
71 }
72 // If execution reaches here, it's an invalid tail_number because it wasn't handled above.
73#if defined(__HIP_DEVICE_COMPILE__)
74 __builtin_unreachable();
75#else
76 throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are "
77 "supported in this pipeline context.");
78#endif
79 }
80};
81
82// Compute optimized pipeline
83// GlobalPrefetchStages: 3
84// LocalPreFillStages: 1
85// LocalPreFetchStages: 1
86// LocalSharedMemoryBuffer: 2
87template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV6DefaultPolicy>
89{
92
97
100
104
107
110
111 static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
112
113 static constexpr index_t APackedSize =
115 static constexpr index_t BPackedSize =
117
118 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
119
121 static constexpr auto I0 = number<0>{};
122 static constexpr auto I1 = number<1>{};
123 static constexpr auto I2 = number<2>{};
124
125 static constexpr index_t BlockSize = Problem::kBlockSize;
126
127 static constexpr index_t MPerBlock = BlockGemmShape::kM;
128 static constexpr index_t NPerBlock = BlockGemmShape::kN;
129 static constexpr index_t KPerBlock = BlockGemmShape::kK;
130
131 template <bool IsWave32Host = false>
132 static constexpr index_t GetVectorSizeA()
133 {
134 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
135 }
136 template <bool IsWave32Host = false>
137 static constexpr index_t GetVectorSizeB()
138 {
139 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
140 }
141 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
142
143 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
144 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
145
146 static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA();
147
148 static constexpr bool kPadM = Problem::kPadM;
149 static constexpr bool kPadN = Problem::kPadN;
150 static constexpr bool kPadK = Problem::kPadK;
151
152 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
153 static constexpr index_t Preshuffle = Problem::Preshuffle;
154
155 static constexpr bool HasHotLoop = Problem::HasHotLoop;
156 static constexpr auto TailNum = Problem::TailNum;
157 static constexpr auto Scheduler = Problem::Scheduler;
158
161
162 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
163 {
164 // clang-format off
165 return concat('_', "pipeline_AgBgCrCompV6", BlockSize,
167 concat('x', kPadM, kPadN, kPadK),
168 concat('x', TailNum),
169 concat('_', KRepeat),
171 concat('_', Preshuffle),
172 concat('_', HasHotLoop));
173 // clang-format on
174 }
175
177 {
178 return Policy::template GetSmemSize<Problem>();
179 }
180
181 CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
182 {
183 return Policy::template IsTransposeC<Problem>();
184 }
185
186 template <GemmPipelineScheduler Scheduler>
187 struct PipelineImpl : public BasePImpl
188 {
189 };
190
191 template <>
193 {
194 CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
195 {
196 constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0);
197 constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1);
198 constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2);
199
200 constexpr index_t WaveSize = 64;
201 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
202 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
203
204 constexpr index_t A_LDS_Read_Width = KPerXDL;
205 constexpr index_t B_LDS_Read_Width = KPerXDL;
206
207 constexpr index_t A_Buffer_Load_Inst_Num =
209 constexpr index_t B_Buffer_Load_Inst_Num =
211
212 constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
213 constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
214
215 constexpr index_t A_LDS_Read_Inst_Num =
216 WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
217 constexpr index_t B_LDS_Read_Inst_Num =
218 WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
219
220 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
221 (BlockSize / WaveSize) /
222 (MPerXDL * NPerXDL * KPerXDL);
223
224 constexpr auto num_ds_read_inst_a =
225 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
226 : A_LDS_Read_Inst_Num / 2;
227 constexpr auto num_ds_read_inst_b =
228 B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
229 : B_LDS_Read_Inst_Num / 2;
230
231 constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
232
233 constexpr auto ds_read_a_issue_cycle =
234 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
235 constexpr auto ds_read_b_issue_cycle =
236 B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
237
238 constexpr auto ds_read_a_mfma_rate =
239 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
240 constexpr auto ds_read_b_mfma_rate =
241 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
242
243 constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
244 constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
245 constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
246 constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
247
248 constexpr auto num_dsread_stage1_a_mfma =
249 (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
250 constexpr auto num_dsread_stage1_b_mfma =
251 (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
252 constexpr auto num_dsread_stage3_a_mfma =
253 (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
254 constexpr auto num_dsread_stage3_b_mfma =
255 (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
256
257 constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num -
258 num_ds_read_inst_a / ds_read_a_mfma_rate -
259 num_ds_read_inst_b / ds_read_b_mfma_rate;
260 constexpr auto num_mfma_per_issue =
261 num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num);
262 constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num;
263 constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num;
264
265 // stage 1
267 ignore = i;
268 if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
269 ds_read_a_mfma_rate)
270 {
271 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
272 }
273 else
274 {
275 __builtin_amdgcn_sched_group_barrier(
276 0x100,
277 num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
278 0); // DS read
279 }
280 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
281 });
283 ignore = i;
284 if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
285 ds_read_b_mfma_rate)
286 {
287 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
288 }
289 else
290 {
291 __builtin_amdgcn_sched_group_barrier(
292 0x100,
293 num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
294 0); // DS read
295 }
296 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
297 });
298
299 // stage 2
301 ignore = i;
302 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
303 ignore = idswrite;
304 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
305 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
306 });
307 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
308 __builtin_amdgcn_sched_group_barrier(
309 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
310 });
312 ignore = i;
313 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
314 ignore = idswrite;
315 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
316 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
317 });
318 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
319 __builtin_amdgcn_sched_group_barrier(
320 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
321 });
322
323 // stage 3
325 ignore = i;
326 if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
327 ds_read_a_mfma_rate)
328 {
329 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
330 }
331 else
332 {
333 __builtin_amdgcn_sched_group_barrier(
334 0x100,
335 num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
336 0); // DS read
337 }
338 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
339 });
341 ignore = i;
342 if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
343 ds_read_b_mfma_rate)
344 {
345 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
346 }
347 else
348 {
349 __builtin_amdgcn_sched_group_barrier(
350 0x100,
351 num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
352 0); // DS read
353 }
354 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
355 });
356 __builtin_amdgcn_sched_barrier(0);
357 }
358
359 template <bool HasHotLoop,
361 typename AsDramBlockWindowTmp,
362 typename BsDramBlockWindowTmp,
363 typename AElementFunction,
364 typename BElementFunction,
365 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
367 bool>* = nullptr>
368 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
369 const AElementFunction& a_element_func,
370 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
371 const BElementFunction& b_element_func,
372 index_t num_loop,
373 void* __restrict__ p_smem) const
374 {
375 // TODO: Add Multi A/B support
376 static_assert(std::tuple_size<remove_cvref_t<AsDramBlockWindowTmp>>::value == 1,
377 "Multi A/B is not yet supported for this pipeline.");
378 static_assert(std::tuple_size<remove_cvref_t<BsDramBlockWindowTmp>>::value == 1,
379 "Multi A/B is not yet supported for this pipeline.");
380
381 using ADramBlockWindowTmp =
382 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
383 using BDramBlockWindowTmp =
384 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
385 static_assert(
386 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
387 std::is_same_v<BDataType,
389 "Data Type conflict on A and B matrix input data type.");
390
391 constexpr bool is_a_col_major =
392 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
393 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
394
395 static_assert(is_a_col_major
396 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
397 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1])
398 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
399 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]),
400 "A block window has incorrect lengths for defined ALayout!");
401 static_assert(is_b_row_major
402 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
403 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1])
404 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
405 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]),
406 "B block window has incorrect lengths for defined BLayout!");
407
409 using ALdsType =
410 remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I0))>;
411 using BLdsType =
412 remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I1))>;
413 auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem);
414 ALdsType& a_lds_block = ABLdsTensorViews.at(I0);
415 BLdsType& b_lds_block = ABLdsTensorViews.at(I1);
416
417 // Tile distribution for load from lds
418 constexpr auto a_lds_load_tile_distr =
419 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
420 constexpr auto b_lds_load_tile_distr =
421 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
422
423 using acopy_dram_type =
424 remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
425 a_lds_block,
426 a_lds_load_tile_distr)
427 .at(I0))>;
428 using bcopy_dram_type =
429 remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
430 b_lds_block,
431 b_lds_load_tile_distr)
432 .at(I0))>;
433
434 using a_copy_lds_window_type =
435 remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
436 a_lds_block,
437 a_lds_load_tile_distr)
438 .at(I1))>;
439 using b_copy_lds_window_type =
440 remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
441 b_lds_block,
442 b_lds_load_tile_distr)
443 .at(I1))>;
444
445 using a_lds_load_tile_distr_type =
446 remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
447 a_lds_block,
448 a_lds_load_tile_distr)
449 .at(I2))>;
450 using b_lds_load_tile_distr_type =
451 remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
452 b_lds_block,
453 b_lds_load_tile_distr)
454 .at(I2))>;
455
456 auto&& aWindows =
457 BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
458 auto&& bWindows =
459 BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
460
461 // A DRAM tile window for load
462 // A LDS tile window for store
463 // A LDS tile for block GEMM
464 acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
465 a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
466 a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);
467
468 // B DRAM tile window for load
469 // B LDS tile window for store
470 // B LDS tile for block GEMM
471 bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
472 b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
473 b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);
474
475 // Block GEMM
476 auto block_gemm = BlockGemm();
477 auto c_block_tile = block_gemm.MakeCBlockTile();
478
479 using ABlockTileDistr =
480 decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
481 using BBlockTileDistr =
482 decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
483
484 using ABlockTile =
485 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
486 using BBlockTile =
487 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
488
489 ABlockTile a_block_tile[Base::GlobalBufferNum];
490 BBlockTile b_block_tile[Base::GlobalBufferNum];
491
492 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
493 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
494
495 constexpr ADramTileWindowStep a_dram_tile_window_step =
496 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
497 constexpr BDramTileWindowStep b_dram_tile_window_step =
498 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
499
500 constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
501 BlockGemm::MakeABlockDistributionEncode())){};
502 constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
503 BlockGemm::MakeBBlockDistributionEncode())){};
504
505 using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
506 using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
507
508 ALdsTile a_lds_tile;
509 BLdsTile b_lds_tile;
510 // -----------------------------------------------------------------------------------------
511 // Gemm pipeline start
512
513 // Global prefetch 1
514 a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
515 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
516 b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
517 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
518
519 // initialize C
520 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
521
522 // Local prefill 1
523 if constexpr(is_a_col_major && !is_a_load_tr_v())
524 {
526 Policy::template MakeShuffledARegTileDistribution<Problem>());
527 transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]);
528 BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
529 }
530 else
531 {
532 BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]);
533 }
534 if constexpr(is_b_row_major && !is_b_load_tr_v())
535 {
537 Policy::template MakeShuffledBRegTileDistribution<Problem>());
538 transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]);
539 BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
540 }
541 else
542 {
543 BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]);
544 }
545
546 // Global prefetch 2
547 a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
548 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
549 b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
550 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
551
552 // Global prefetch 3
553 a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
554 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
555 b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
556 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
557
559
560 // Local prefetch 1
561 BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
562 BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
563
564 if(HasHotLoop)
565 {
566 index_t i = 0;
567 do
568 {
569 auto LoopFunc = [&](auto vmem_buf_idx) {
570 static_for<0, KRepeat, 1>{}([&](auto k0) {
571 if constexpr(k0 == (KRepeat - 1))
572 {
574
575 // Local prefill 2
576 if constexpr(is_a_col_major && !is_a_load_tr_v())
577 {
579 Policy::template MakeShuffledARegTileDistribution<
580 Problem>());
581 transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
582 BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
583 }
584 else
585 {
586 BasePImpl::LocalPrefill(a_copy_lds_window,
587 a_block_tile[vmem_buf_idx]);
588 }
589 if constexpr(is_b_row_major && !is_b_load_tr_v())
590 {
592 Policy::template MakeShuffledBRegTileDistribution<
593 Problem>());
594 transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
595 BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
596 }
597 else
598 {
599 BasePImpl::LocalPrefill(b_copy_lds_window,
600 b_block_tile[vmem_buf_idx]);
601 }
602
603 // Global prefetch 4
604 a_block_tile[vmem_buf_idx] =
605 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
606 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
607 b_block_tile[vmem_buf_idx] =
608 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
609 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
610
612 }
613 block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
614
615 // Local prefetch 2
616 BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
617 BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
618 });
619
621 };
622
623 LoopFunc(I0);
624 LoopFunc(I1);
625
627 } while(i < (num_loop - Base::PrefetchStages));
628 }
629
630 auto ReadWriteCompFunc = [&](auto vmem_buf_idx) {
631 static_for<0, KRepeat, 1>{}([&](auto k0) {
632 if constexpr(k0 == (KRepeat - 1))
633 {
635
636 // Local prefill 3
637 if constexpr(is_a_col_major && !is_a_load_tr_v())
638 {
640 Policy::template MakeShuffledARegTileDistribution<Problem>());
641 transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
642 BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
643 }
644 else
645 {
646 BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]);
647 }
648 if constexpr(is_b_row_major && !is_b_load_tr_v())
649 {
651 Policy::template MakeShuffledBRegTileDistribution<Problem>());
652 transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
653 BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
654 }
655 else
656 {
657 BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]);
658 }
659
661 }
662
663 block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
664
665 BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
666 BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
667 });
668
670 };
671
672 auto ReadCompFunc = [&]() {
673 static_for<0, KRepeat - 1, 1>{}([&]() {
674 __syncthreads();
675 block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
676
677 // Local prefetch 4
678 BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
679 BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
680
681 __syncthreads();
682 });
683
684 block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
685
687 };
688
689 if constexpr(TailNum == TailNumber::Odd)
690 {
691 ReadWriteCompFunc(I0);
692 ReadWriteCompFunc(I1);
693 ReadCompFunc();
694 }
695 else if constexpr(TailNum == TailNumber::Even)
696 {
697 ReadWriteCompFunc(I0);
698 ReadCompFunc();
699 }
700
701 return c_block_tile;
702 }
703 };
704
705 public:
706 template <typename AsDramBlockWindowTmp,
707 typename BsDramBlockWindowTmp,
708 typename AElementFunction,
709 typename BElementFunction,
710 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
712 bool>* = nullptr>
713 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
714 const AElementFunction& a_element_func,
715 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
716 const BElementFunction& b_element_func,
717 index_t num_loop,
718 void* __restrict__ p_smem) const
719 {
720 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
721 a_dram_block_window_tmp,
722 a_element_func,
723 b_dram_block_window_tmp,
724 b_element_func,
725 num_loop,
726 p_smem);
727 }
728
729 template <typename AsDramBlockWindowTmp,
730 typename BsDramBlockWindowTmp,
731 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
733 bool>* = nullptr>
734 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
735 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
736 const index_t num_loop,
737 void* __restrict__ p_smem) const
738 {
739 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
740 a_dram_block_window_tmp,
741 [](auto& e, const ADataType& a) { e = a; },
742 b_dram_block_window_tmp,
743 [](auto& e, const BDataType& b) { e = b; },
744 num_loop,
745 p_smem);
746 }
747
748 template <typename ADramBlockWindowTmp,
749 typename BDramBlockWindowTmp,
750 typename AElementFunction,
751 typename BElementFunction,
752 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
754 bool>* = nullptr>
755 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
756 const AElementFunction& a_element_func,
757 const BDramBlockWindowTmp& b_dram_block_window_tmp,
758 const BElementFunction& b_element_func,
759 index_t num_loop,
760 void* __restrict__ p_smem) const
761 {
762 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
763 a_element_func,
764 ck_tile::make_tuple(b_dram_block_window_tmp),
765 b_element_func,
766 num_loop,
767 p_smem);
768 }
769};
770} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:16
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:43
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:22
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:17
static constexpr index_t PrefillStages
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:18
static CK_TILE_HOST constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:29
static constexpr index_t HotloopUnroll
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:20
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:19
static CK_TILE_HOST constexpr bool BlockHasHotloop(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:24
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:368
static CK_TILE_DEVICE constexpr auto HotLoopScheduler()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:194
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:188
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:89
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:155
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:149
static constexpr auto I0
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:121
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:157
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:148
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:144
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:156
static constexpr index_t BPackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:115
BaseGemmPipelineAgBgCrCompV6< Problem > Base
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:90
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:755
static constexpr index_t KRepeat
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:146
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:106
static constexpr auto is_b_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:160
GemmPipelineAgBgCrImplBase< Problem, Policy > BasePImpl
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:91
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:143
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:181
static constexpr auto is_a_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:159
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:98
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:103
static constexpr auto I2
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:123
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:120
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:176
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:102
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:94
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:129
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:108
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:118
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:99
static constexpr auto I1
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:122
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:734
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:101
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:96
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:125
static constexpr index_t APackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:113
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:162
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:132
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:137
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:141
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:109
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:128
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:95
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:127
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:150
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:152
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:93
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:713
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:153
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_comp_v6.hpp:105
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile &dst_block_tile, const SrcTileWindow &lds_tile_window, bool_constant< LoadTranspose >={}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:73
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43