gemm_pipeline_ag_bg_cr_comp_async.hpp Source File

gemm_pipeline_ag_bg_cr_comp_async.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_async.hpp Source File
gemm_pipeline_ag_bg_cr_comp_async.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3#pragma once
4#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A Tile Window: global memory
12// B Tile Window: global memory
13// C Distributed tensor: register
14template <typename Problem>
16{
17 static constexpr index_t PrefetchStages = 2;
18 static constexpr index_t PrefillStages = 1;
19 static constexpr index_t GlobalBufferNum = 1;
20
21 CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
22 {
23 return num_loop > PrefetchStages;
24 }
25
27 {
28 if(num_loop == 1)
29 {
30 return TailNumber::One;
31 }
32 if(num_loop % PrefetchStages == 1)
33 {
34 return TailNumber::Three;
35 }
36 else
37 {
38 return TailNumber::Two;
39 }
40 }
41
42 template <typename RunFunction>
43 CK_TILE_HOST_DEVICE static auto
44 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
45 {
46 // Handle all the valid cases.
47 if(has_hot_loop)
48 {
49 if(tail_number == TailNumber::Three)
50 {
51 return run_func(bool_constant<true>{},
53 }
54 else if(tail_number == TailNumber::Two)
55 {
56 return run_func(bool_constant<true>{},
58 }
59 }
60 else
61 {
62 if(tail_number == TailNumber::Three)
63 {
64 return run_func(bool_constant<false>{},
66 }
67 else if(tail_number == TailNumber::Two)
68 {
69 return run_func(bool_constant<false>{},
71 }
72 else
73 {
74 return (run_func(bool_constant<false>{},
76 }
77 }
78 // If execution reaches here, it's an invalid tail_number because it wasn't handled above.
79#if defined(__HIP_DEVICE_COMPILE__)
80 __builtin_unreachable();
81#else
82 throw std::logic_error(
83 "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
84#endif
85 }
86};
87
94template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
96{
99
104
108
111
114
117
118 static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
119
120 static constexpr index_t APackedSize =
122 static constexpr index_t BPackedSize =
124
126 using I0 = number<0>;
127 using I1 = number<1>;
128 using I2 = number<2>;
129
130 static constexpr index_t BlockSize = Problem::kBlockSize;
131
132 static constexpr index_t MPerBlock = BlockGemmShape::kM;
133 static constexpr index_t NPerBlock = BlockGemmShape::kN;
134 static constexpr index_t KPerBlock = BlockGemmShape::kK;
135
136 template <bool IsWave32Host = false>
137 static constexpr index_t GetVectorSizeA()
138 {
139 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
140 }
141 template <bool IsWave32Host = false>
142 static constexpr index_t GetVectorSizeB()
143 {
144 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
145 }
146 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
147
148 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
149 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
150
151 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
152 static constexpr index_t Preshuffle = Problem::Preshuffle;
153
154 static constexpr bool kPadM = Problem::kPadM;
155 static constexpr bool kPadN = Problem::kPadN;
156 static constexpr bool kPadK = Problem::kPadK;
157
158 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
159
160 static constexpr bool HasHotLoop = Problem::HasHotLoop;
161 static constexpr auto TailNum = Problem::TailNum;
162 static constexpr auto Scheduler = Problem::Scheduler;
163
166
168 {
169 return Policy::template GetSmemSize<Problem>();
170 }
171
172 CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
173 {
174 return Policy::template IsTransposeC<Problem>();
175 }
176
177 template <GemmPipelineScheduler Scheduler>
179 {
180 };
181
182 template <>
184 {
186
187 CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
188 {
189 constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
190 constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
191 constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
192
193 constexpr index_t WaveSize = get_warp_size();
194
195 constexpr index_t A_Buffer_Load_Inst_Num =
197 constexpr index_t B_Buffer_Load_Inst_Num =
199
200 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
201 (BlockSize / WaveSize) /
202 (MPerXDL * NPerXDL * KPerXDL);
203
204 constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
205 constexpr auto num_issue = num_buffer_load_inst;
206
208 // TODO: this will likely need to be redesigned after (1) changes to reading from
209 // LDS and (2) re-profiling
210 ignore = i;
211 __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
212 __builtin_amdgcn_sched_group_barrier(
213 LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1
214 __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1
215 __builtin_amdgcn_sched_group_barrier(
216 LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1
217 __builtin_amdgcn_sched_group_barrier(
218 LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6
219 });
220 __builtin_amdgcn_sched_barrier(0);
221 }
222
223 template <bool HasHotLoop,
225 typename AsDramBlockWindowTmp,
226 typename BsDramBlockWindowTmp,
227 typename AElementFunction,
228 typename BElementFunction,
229 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
231 bool>* = nullptr>
232 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
233 const AElementFunction& a_element_func,
234 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
235 const BElementFunction& b_element_func,
236 index_t num_loop,
237 void* __restrict__ p_smem_0,
238 void* __restrict__ p_smem_1) const
239 {
240 // TODO support multi-ABD
241 static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
242 static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
243 using ADramBlockWindowTmp =
244 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
245 using BDramBlockWindowTmp =
246 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
247 // TODO currently fused elementwise are not supported
248 ignore = a_element_func;
249 ignore = b_element_func;
250 static_assert(std::is_same_v<remove_cvref_t<decltype(a_element_func)>,
252 static_assert(std::is_same_v<remove_cvref_t<decltype(b_element_func)>,
254 static_assert(
255 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
256 std::is_same_v<BDataType,
258 "Data Type conflict on A and B matrix input data type.");
259
260 constexpr bool is_a_col_major =
261 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
262 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
263
264 static_assert(is_a_col_major
265 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
266 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
267 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
268 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
269 "A block window has incorrect lengths for defined ALayout!");
270 static_assert(is_b_row_major
271 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
272 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
273 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
274 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
275 "B block window has incorrect lengths for defined BLayout!");
276
278 // A DRAM tile window(s) for load
279 auto a_tile_windows = generate_tuple(
280 [&](auto idx) {
281 return make_tile_window(
282 a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
284 a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
285 Policy::template MakeADramTileDistribution<Problem>());
286 },
287 number<AsLayout::size()>{});
288 // B DRAM window(s) for load
289 auto b_tile_windows = generate_tuple(
290 [&](auto idx) {
291 return make_tile_window(
292 b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
294 b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
295 Policy::template MakeBDramTileDistribution<Problem>());
296 },
297 number<BsLayout::size()>{});
298
299 // this pipeline has a pair of LDS buffers per logical tile
300 auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
301 auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
302
303 // set up LDS tile shapes
304 constexpr auto a_lds_shape = []() {
305 if constexpr(is_a_load_tr_v)
307 else
309 }();
310
311 constexpr auto b_lds_shape = []() {
312 if constexpr(is_b_load_tr_v)
314 else
316 }();
317
318 // LDS tile windows for storing, one per LDS buffer
319 auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
320
321 auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
322
323 auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
324
325 auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
326
327 // initialize DRAM window steps, used to advance the DRAM windows
328 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
329 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
330
331 constexpr ADramTileWindowStep a_dram_tile_window_step =
332 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
333 constexpr BDramTileWindowStep b_dram_tile_window_step =
334 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
335
336 // read A(0), B(0) from DRAM to LDS window(0)
337 // and advance the DRAM windows
339 a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
341 b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
342
343 // initialize block gemm
344 auto block_gemm = BlockGemm();
345
346 // initialize C block tile
347 auto c_block_tile = block_gemm.MakeCBlockTile();
348 clear_tile(c_block_tile);
349
350 // read A(1), B(1) from DRAM to LDS window(1)
351 // and advance the DRAM windows
353 a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
355 b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
356
357 // tile distribution for the register tiles
358 constexpr auto ALdsTileDistr =
359 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
360 constexpr auto BLdsTileDistr =
361 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
362
363 using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
364 using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
365
366 // register tiles; double buffering -> a register tile corresponds to a LDS tile window
367 ALdsTile a_block_tile0, a_block_tile1;
368 BLdsTile b_block_tile0, b_block_tile1;
369
370 constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
371 if constexpr(is_a_load_tr_v)
374 typename decltype(ALdsTileDistr)::DstrEncode,
375 typename Problem::ADataType>::TransposedDstrEncode{});
376 else
377 return ALdsTileDistr;
378 }();
379 constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
380 if constexpr(is_b_load_tr_v)
383 typename decltype(BLdsTileDistr)::DstrEncode,
384 typename Problem::BDataType>::TransposedDstrEncode{});
385 else
386 return BLdsTileDistr;
387 }();
388
389 // LDS tile windows for reading;
390 // they share the data pointer with the LDS windows for storing
391 // but also associate with a distribution to produce a register tile when reading
392 auto a_lds_ld_window0 =
393 make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
394 auto a_lds_ld_window1 =
395 make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
396 auto b_lds_ld_window0 =
397 make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
398 auto b_lds_ld_window1 =
399 make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
400
401 static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
402 !(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
403 !(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
404 !(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
405 "LDS windows must not be linear");
406
407 // write to LDS window(0) must complete before the local prefetch
409 // read A(0), B(0) from LDS window(0) to pipeline registers(0)
410 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
411 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
412 // LDS window(0) contents are overwritten below by global prefetch, need to sync
414 // read A(2), B(2) from DRAM to LDS window(0)
415 // and advance the DRAM windows
417 a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
419 b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
420
421 if(HasHotLoop)
422 {
423 // we have had 3 global prefetches so far, indexed (0, 1, 2).
424 index_t i_global_read = amd_wave_read_first_lane(3);
425 // alternate ping: (read to register tile(1), use register tile(0) as gemm input)
426 // pong: (read to register tile(0), use register tile(1) as gemm input)
427 do
428 {
429 // ping
430 {
431 // read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
432 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
433 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
434 // LDS window(1) contents are overwritten by global prefetch, need to sync
436 // read A(i), B(i) from DRAM to LDS window(1)
437 // and advance the DRAM windows
438 Base::GlobalPrefetchAsync(a_copy_lds_window1,
439 a_tile_windows[number<0>{}],
440 a_dram_tile_window_step);
441 Base::GlobalPrefetchAsync(b_copy_lds_window1,
442 b_tile_windows[number<0>{}],
443 b_dram_tile_window_step);
444 // C(i-3) = A(i-3) @ B(i-3)
445 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
447 }
448 // pong
449 {
450 // write to LDS window(0) must complete before the local prefetch
452 // read A(i), B(i) from LDS window(0) to pipeline registers(0)
453 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
454 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
455 // LDS window(0) contents are overwritten by global prefetch, need to sync
457 // read A(i+1), B(i+1) from DRAM to LDS window(0)
458 // and advance the DRAM windows
459 Base::GlobalPrefetchAsync(a_copy_lds_window0,
460 a_tile_windows[number<0>{}],
461 a_dram_tile_window_step);
462 Base::GlobalPrefetchAsync(b_copy_lds_window0,
463 b_tile_windows[number<0>{}],
464 b_dram_tile_window_step);
465 // C(i-2) = A(i-2) @ B(i-2)
466 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
468 }
469 i_global_read += 2;
470 } while(i_global_read < num_loop);
471 }
472
473 // 3 block gemms remaining
474 if constexpr(TailNum == TailNumber::Three)
475 {
476 {
477 // read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
478 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
479 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
480 // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2)
481 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
482 }
483 {
484 // write to LDS window(0) must complete before the local prefetch
486 // read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
487 Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
488 Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
489 // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
490 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
491 }
492 {
493 // C(num_loop) = A(num_loop) @ B(num_loop)
494 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
495 }
496 }
497 else if(TailNum == TailNumber::Two)
498 // 2 block gemms remaining
499 {
500 {
501 // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
502 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
503 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
504 // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
505 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
506 }
507 {
508 // C(num_loop) = A(num_loop) @ B(num_loop)
509 block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
510 }
511 }
512 else if(TailNum == TailNumber::One)
513 {
515 block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
516 __builtin_amdgcn_sched_barrier(0);
517 }
518 return c_block_tile;
519 }
520 };
521
522 template <typename ADramBlockWindowTmp,
523 typename BDramBlockWindowTmp,
524 typename AElementFunction,
525 typename BElementFunction>
526 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
527 const AElementFunction& a_element_func,
528 const BDramBlockWindowTmp& b_dram_block_window_tmp,
529 const BElementFunction& b_element_func,
530 index_t num_loop,
531 void* p_smem_0,
532 void* p_smem_1) const
533 {
534 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
535 a_dram_block_window_tmp,
536 a_element_func,
537 b_dram_block_window_tmp,
538 b_element_func,
539 num_loop,
540 p_smem_0,
541 p_smem_1);
542 }
543
544 public:
545 template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
546 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
547 const BDramBlockWindowTmp& b_dram_block_window_tmp,
548 const index_t num_loop,
549 void* __restrict__ p_smem_0,
550 void* __restrict__ p_smem_1) const
551 {
552 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
553 a_dram_block_window_tmp,
554 [](const ADataType& a) { return a; },
555 b_dram_block_window_tmp,
556 [](const BDataType& b) { return b; },
557 num_loop,
558 p_smem_0,
559 p_smem_1);
560 }
561};
562} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ One
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:27
@ Two
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:28
@ Three
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:29
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition arch.hpp:288
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
@ MFMA
Definition arch.hpp:426
@ DS_READ
Definition arch.hpp:431
@ VMEM_READ
Definition arch.hpp:428
constexpr bool is_tile_window_linear_v
Helper variable template to check if a type is a linear tile window.
Definition tile_window_linear.hpp:1119
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:16
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:19
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:44
static CK_TILE_HOST constexpr bool BlockHasHotloop(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:21
static CK_TILE_HOST constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:26
static constexpr index_t PrefillStages
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:18
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:17
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:185
static CK_TILE_DEVICE constexpr auto HotLoopScheduler()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:187
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:232
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:179
Compute-optimized asynchronous pipeline version, based on V4.
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:96
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:161
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:154
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:116
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:156
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:106
GemmPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:98
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:132
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:103
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:151
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:113
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_0, void *p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:526
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:112
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:109
number< 0 > I0
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:126
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:130
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:107
static constexpr auto is_b_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:165
static constexpr index_t BPackedSize
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:122
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:100
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:134
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:115
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:142
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:110
static constexpr auto is_a_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:164
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:152
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:158
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:133
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:125
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:155
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:101
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:148
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:172
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:105
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:149
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:546
static constexpr index_t APackedSize
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:120
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:167
number< 2 > I2
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:128
BaseGemmPipelineAgBgCrCompAsync< Problem > Base
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:97
number< 1 > I1
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:127
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:137
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:162
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:160
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:102
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_async.hpp:146
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow &dst_block_window, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:48
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile &dst_block_tile, const SrcTileWindow &lds_tile_window, bool_constant< LoadTranspose >={}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:73
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:437
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43