gemm_aquant_pipeline_ag_bg_cr_v3.hpp Source File

gemm_aquant_pipeline_ag_bg_cr_v3.hpp Source File

Composable Kernel: gemm_aquant_pipeline_ag_bg_cr_v3.hpp Source File
gemm_aquant_pipeline_ag_bg_cr_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7#include <sstream>
8
9#include "ck_tile/core.hpp"
14
15namespace ck_tile {
16
17// Compute optimized pipeline
18// GlobalPrefetchStages: 2
19// LocalPreFillStages: 1
20// LocalPreFetchStages: 1
21// LocalSharedMemoryBuffer: 1
22
23template <typename Problem>
25{
26 template <typename RunFunction>
27 CK_TILE_HOST_DEVICE static auto
28 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
29 {
30 if(has_hot_loop)
31 {
32 if(tail_number == ck_tile::TailNumber::Full)
33 {
34 return run_func(
37 }
38 else if(tail_number == ck_tile::TailNumber::Odd)
39 {
40 return run_func(
43 }
44 else if(tail_number == ck_tile::TailNumber::Even)
45 {
46 return run_func(
49 }
50 else
51 {
52 throw std::runtime_error("Unsupported tail number for this operation !!!");
53 }
54 }
55 else
56 {
57 if(tail_number == ck_tile::TailNumber::Full)
58 {
59 return run_func(
62 }
63 else if(tail_number == ck_tile::TailNumber::Odd)
64 {
65 return run_func(
68 }
69 else if(tail_number == ck_tile::TailNumber::Even)
70 {
71 return run_func(
74 }
75 else
76 {
77 throw std::runtime_error("Unsupported tail number for this operation !!!");
78 }
79 }
80 }
81};
82
83template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
85{
88
95
96 static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
97 static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
98
99 using I0 = number<0>;
100 using I1 = number<1>;
101 using I2 = number<2>;
102
103 static constexpr index_t APackedSize =
105 static constexpr index_t BPackedSize =
107
108 static constexpr index_t AQPackedSize =
110
115
117
118 static constexpr index_t BlockSize = Problem::kBlockSize;
119 static constexpr index_t MPerBlock = BlockGemmShape::kM;
120 static constexpr index_t NPerBlock = BlockGemmShape::kN;
121 static constexpr index_t KPerBlock = BlockGemmShape::kK;
122 static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
123
// Per-operand vector sizes, forwarded from Policy and instantiated for this
// Problem. Print() reports the A, B and AQ values.
124 static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
125 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
126 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
// Vector size for the AQ operand, likewise taken from Policy for this Problem.
127 static constexpr index_t GetVectorSizeAQ()
128 {
129 return Policy::template GetVectorSizeAQ<Problem>();
130 }
131
// LDS packing factors for A and B, forwarded from Policy. Print() reports
// these values as both the A/B LDS read width and the A/B LDS write width.
132 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
133 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
134
135 static constexpr bool kPadM = Problem::kPadM;
136 static constexpr bool kPadN = Problem::kPadN;
137 static constexpr bool kPadK = Problem::kPadK;
138
139 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
140 static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
141
142 static constexpr bool HasHotLoop = Problem::HasHotLoop;
143 static constexpr auto TailNum = Problem::TailNum;
144 static constexpr auto Scheduler = Problem::Scheduler;
145
147
148 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
149 {
150 // clang-format off
151 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
152 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
153 return concat('_', "aquant_pipeline_AgBgCrCompV3",
155 BlockSize,
156 concat('x', WaveNumM, WaveNumN),
157 concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
158 concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
159 // clang-format on
160 }
161
163 {
164 return Policy::template GetSmemSize<Problem>();
165 }
166
167 CK_TILE_HOST static std::string Print()
168 {
169 constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
170 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
171 constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
172
173 constexpr index_t WaveSize = 64;
174 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
175 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
176
177 constexpr index_t A_LDS_Read_Width = GetSmemPackA();
178 constexpr index_t B_LDS_Read_Width = GetSmemPackB();
179
180 constexpr index_t A_LDS_Write_Width = GetSmemPackA();
181 constexpr index_t B_LDS_Write_Width = GetSmemPackB();
182
183 constexpr index_t A_Buffer_Load_Inst_Num =
185 constexpr index_t B_Buffer_Load_Inst_Num =
187 constexpr index_t AQ_Buffer_Load_Inst_Num =
189
190 constexpr index_t A_LDS_Write_Inst_Num =
191 MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
192 constexpr index_t B_LDS_Write_Inst_Num =
193 NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
194
195 constexpr index_t A_LDS_Read_Inst_Num =
196 WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
197 constexpr index_t B_LDS_Read_Inst_Num =
198 WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
199
200 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
201 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
202
203 auto str = std::stringstream{};
204
205 str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
206 << "AQ vector size: " << GetVectorSizeAQ() << "\n"
207 << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
208 << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
209 << ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
210 << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
211 << "\n"
212 << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
213 << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
214 << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
215 << "KPack: " << BlockGemm::Traits::KPack << "\n"
216 << "PrefetchStages: " << PrefetchStages << "\n";
217 return str.str();
218 }
219
220 template <GemmPipelineScheduler Scheduler>
222 {
223 };
224
225 template <>
227 {
229
230 template <bool HasHotLoop,
232 typename ADramBlockWindowTmp,
233 typename BDramBlockWindowTmp,
234 typename AQDramBlockWindowTmp,
235 typename AElementFunction,
236 typename BElementFunction>
237 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
238 const AElementFunction& a_element_func,
239 const BDramBlockWindowTmp& b_dram_block_window_tmp,
240 const BElementFunction& b_element_func,
241 const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
242 index_t m,
243 index_t num_loop,
244 void* p_smem) const
245 {
246 static_assert(
247 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
248 std::is_same_v<BDataType,
250 std::is_same_v<AQDataType,
252 "A/B/AQ Dram block window should have the same data type as appropriate "
253 "([A|B|AQ]DataType) defined in Problem definition!");
254
255 constexpr bool is_a_col_major =
256 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
257 constexpr bool is_aq_col_major =
258 std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
259 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
260
261 static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
262
263 static_assert(is_a_col_major
264 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
265 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
266 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
267 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
268 "A block window has incorrect lengths for defined ALayout!");
269 static_assert(is_b_row_major
270 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
271 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
272 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
273 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
274 "B block window has incorrect lengths for defined BLayout!");
275
276 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
277 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
278 using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
279
280 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
281
282 constexpr auto a_lds_load_tile_distr =
283 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
284 constexpr auto b_lds_load_tile_distr =
285 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
286
287 auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
288 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
289 auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
290 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
291 auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
292
293 using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
294 using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
295 using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
296
297 using ABlockTile =
298 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
299 using BBlockTile =
300 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
301 using AQBlockTile =
302 decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
303
304 auto block_gemm = BlockGemm();
305
306 ABlockTile a_block_tile;
307 BBlockTile b_block_tile;
308 AQBlockTile aq_block_tile[2];
309 int currIdx = 0;
310
311 auto c_block_tile = block_gemm.MakeCBlockTile();
312
313 constexpr ADramTileWindowStep a_dram_tile_window_step =
314 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
315 constexpr BDramTileWindowStep b_dram_tile_window_step =
316 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
317
318 // only row_major for AQ
319 const AQDramTileWindowStep aq_dram_tile_window_step =
321 BlockGemm::WarpGemm::kM,
322 0)
324
325 // DRAM prefetch (global read 0)
326 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
327 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
329 aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
330
331 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
332
333 if constexpr(is_a_col_major)
334 {
336 Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
337 transpose_tile2d(a_shuffle_tmp, a_block_tile);
338 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
339 }
340 else
341 {
342 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
343 }
344
345 if constexpr(is_b_row_major)
346 {
348 Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
349 transpose_tile2d(b_shuffle_tmp, b_block_tile);
350 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
351 }
352 else
353 {
354 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
355 }
356
357 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
358 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
359
361
362 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
363
364 __builtin_amdgcn_sched_barrier(0);
365
366 if constexpr(HasHotLoop)
367 {
368 index_t i = 0;
369 do
370 {
372
373 if constexpr(is_a_col_major)
374 {
376 Policy::template MakeShuffledARegTileDistribution<Problem>());
377 transpose_tile2d(a_shuffle_tmp, a_block_tile);
378 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
379 }
380 else
381 {
382 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
383 }
384 if constexpr(is_b_row_major)
385 {
387 Policy::template MakeShuffledBRegTileDistribution<Problem>());
388 transpose_tile2d(b_shuffle_tmp, b_block_tile);
389 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
390 }
391 else
392 {
393 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
394 }
395
396 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
397 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
398 Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
399 aq_copy_dram_window,
400 aq_dram_tile_window_step);
401
402 block_gemm(
403 c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
404
405 currIdx = (currIdx + 1) % 2;
406
408
409 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
410 __builtin_amdgcn_sched_barrier(0);
411
412 i += 1;
413 } while(i < (num_loop - 1));
414 }
415 // tail
416 if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
417 {
418 block_gemm(
419 c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
420 }
421 else
422 {
423 Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
424 aq_copy_dram_window,
425 aq_dram_tile_window_step);
426 block_gemm(
427 c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
429
430 currIdx = (currIdx + 1) % 2;
431
432 if constexpr(is_a_col_major)
433 {
435 Policy::template MakeShuffledARegTileDistribution<Problem>());
436 transpose_tile2d(a_shuffle_tmp, a_block_tile);
437 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
438 }
439 else
440 {
441 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
442 }
443 if constexpr(is_b_row_major)
444 {
446 Policy::template MakeShuffledBRegTileDistribution<Problem>());
447 transpose_tile2d(b_shuffle_tmp, b_block_tile);
448 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
449 }
450 else
451 {
452 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
453 }
455 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
456 block_gemm(
457 c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
458 }
459 return c_block_tile;
460 }
461 };
// Convenience entry point that takes no element-wise functions.
// Dispatches to PipelineImpl<Scheduler> with this pipeline's compile-time
// HasHotLoop and TailNum (both copied from Problem). It passes identity
// lambdas as a_element_func / b_element_func, so the A and B tiles reach
// Base::LocalPrefill unmodified.
// m, num_loop and p_smem are forwarded unchanged:
//   - p_smem is the LDS base pointer handed to Base::GetABLdsTensorViews.
//   - num_loop bounds the hot loop inside PipelineImpl.
// Returns the C block tile that PipelineImpl accumulates.
462 template <typename ADramBlockWindowTmp,
463 typename BDramBlockWindowTmp,
464 typename AQDramBlockWindowTmp>
465 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
466 const BDramBlockWindowTmp& b_dram_block_window_tmp,
467 const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
468 index_t m,
469 index_t num_loop,
470 void* p_smem) const
471 {
472 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
473 a_dram_block_window_tmp,
474 [](const ADataType& a) { return a; },
475 b_dram_block_window_tmp,
476 [](const BDataType& b) { return b; },
477 aq_dram_block_window_tmp,
478 m,
479 num_loop,
480 p_smem);
481 }
482};
483
484} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
PipelineImplBase Base
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:228
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, const AQDramBlockWindowTmp &aq_dram_block_window_tmp, index_t m, index_t num_loop, void *p_smem) const
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:237
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:222
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:85
static constexpr bool HasHotLoop
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:142
static constexpr index_t GetVectorSizeB()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:125
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:92
static constexpr index_t GetSmemPackA()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:132
static constexpr index_t GetVectorSizeAQ()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:127
static constexpr auto TailNum
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:143
static constexpr auto Scheduler
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:144
static CK_TILE_HOST std::string Print()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:167
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:162
remove_cvref_t< typename Problem::AQLayout > AQLayout
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:112
static constexpr index_t KPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:121
static constexpr index_t GetSmemPackB()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:133
static constexpr bool kPadN
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:136
static constexpr index_t APackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:103
GemmAQuantPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:87
static constexpr bool DoubleSmemBuffer
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:139
number< 1 > I1
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:100
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:93
static constexpr index_t AQPackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:108
remove_cvref_t< typename Problem::ADataType > ADataType
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:89
static constexpr index_t NPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:120
static constexpr index_t MPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:119
static constexpr bool PreshuffleQuant
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:140
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:116
remove_cvref_t< typename Problem::ALayout > ALayout
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:111
static constexpr bool kPadM
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:135
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:94
BaseGemmPipelineAgBgCrCompV3< Problem > Base
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:86
static constexpr index_t GetVectorSizeA()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:124
static constexpr index_t BlockSize
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:118
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
static constexpr bool kPadK
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:137
static constexpr index_t KPerBlockAQ
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:122
remove_cvref_t< typename Problem::AQDataType > AQDataType
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:90
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const AQDramBlockWindowTmp &aq_dram_block_window_tmp, index_t m, index_t num_loop, void *p_smem) const
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:465
remove_cvref_t< typename Problem::BDataType > BDataType
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:91
number< 0 > I0
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:99
number< 2 > I2
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:101
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:114
remove_cvref_t< typename Problem::BLayout > BLayout
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:113
static constexpr index_t BPackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:105
static constexpr index_t GetVectorSizeC()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:126
static CK_TILE_HOST const std::string GetName()
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:148
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:25
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_aquant_pipeline_ag_bg_cr_v3.hpp:28
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:18
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:14
static constexpr index_t NPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:26
static constexpr index_t KPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:27
typename Base::BDataType BDataType
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:18
CK_TILE_DEVICE constexpr auto GetAQDramLoadWindow(const AQDramBlockWindowTmp &aq_dram_block_window_tmp) const
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:37
static constexpr index_t KPerBlockAQ
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:29
static constexpr index_t MPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile &dst_block_tile, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:39
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81