mixed_prec_flatmm_kernel.hpp Source File

mixed_prec_flatmm_kernel.hpp Source File#

Composable Kernel: mixed_prec_flatmm_kernel.hpp Source File
mixed_prec_flatmm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
11
13
14namespace ck_tile {
15
16template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
17struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
18{
20
31 static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
32 static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
33
36 // Below type is actually accumulation data type - the output of block GEMM.
38
40 static constexpr int N_Pack = 2;
41
42 static constexpr index_t NumDTensor = DsDataType::size();
43
44 static constexpr auto I0 = number<0>();
45 static constexpr auto I1 = number<1>();
46 static constexpr auto I2 = number<2>();
47 static constexpr auto I3 = number<3>();
48 static constexpr auto I4 = number<4>();
49
50 static_assert(DsLayout::size() == DsDataType::size(),
51 "The size of DsLayout and DsDataType should be the same");
52 // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
53
// Builds the identifier string for this kernel instantiation from three pieces:
// the literal "mixed_prec_gemm", the A/B precision string, and the pipeline name.
// The leading '_' is handed to concat as its first argument (presumably the
// separator placed between the pieces - confirm against concat.hpp).
54 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
55 {
56 // clang-format off
// gemm_prec_str is a function template returning std::string, so it has to be
// invoked here. Passing the bare name hands concat the function itself instead
// of the precision string it produces.
57 return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>(), FlatmmPipeline::GetName());
58 // clang-format on
59 }
60
// Computes the launch grid (dim3) for this kernel on the host.
//
// Non-persistent kernels: x = TilePartitioner::GridSize(M, N), y = 1, z = k_batch.
// Persistent kernels:     x = min(multiProcessorCount * max active blocks per CU,
//                                 TilePartitioner::GridSize(M, N)), y = 1, z = k_batch,
//                         and k_batch is asserted to be 1.
//
// NOTE(review): the device properties are queried for device id 0, not for the
// device that is current on the calling thread - confirm this is intended on
// multi-GPU hosts.
// NOTE(review): both HIP status codes are stored in `e` and never inspected. If
// either call fails, maxActiveBlocksPerCU stays 0 (or prop is unset) and the
// resulting grid x-dimension may be 0 or garbage.
61 template <class ScaleM, class ScaleN>
62 CK_TILE_HOST static constexpr auto
63 GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
64 {
65 if constexpr(UsePersistentKernel)
66 {
67 hipDeviceProp_t prop;
68 int deviceId = 0; // default device
69
70 constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
// The occupancy query below is made with zero bytes of dynamic shared memory.
71 int dync_smem_size = 0;
72 int maxActiveBlocksPerCU = 0;
73
74 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
75
// Ask the runtime how many blocks of `block_size` threads of the kentry<...>
// instantiation can be resident per multiprocessor. One template argument of
// kentry (original listing line 80) is not visible in this extraction.
76 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
77 &maxActiveBlocksPerCU,
78 reinterpret_cast<void*>(
79 kentry<1,
81 FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
82 block_size,
83 dync_smem_size);
84
// Upper bound on simultaneously resident blocks, and the number of output tiles.
85 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
86 const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
87
88 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
89 // << ", persistent_block_size: " << persistent_block_size
90 // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
91
// The persistent path only supports k_batch == 1 (checked in debug builds only).
92 assert(kargs.k_batch == 1);
93 return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
94 }
95 else
96 {
// One block per output tile; split-K batches are laid out along grid z.
97 return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
98 }
99 }
100
102
// MakeGemmTensorViews: wraps the raw A, flat-B, D, E and B-scale pointers in
// tensor views and returns them as a tuple in the order
//   (A view, flat-B view, tuple of D views, E view, flat B-scale view).
//
// NOTE: in this listing the line carrying the function name and the first
// parameter (a_ptr), and the lines that open each view-construction call, are
// missing. The comments below describe only the arguments that are visible.
103 template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
104 CK_TILE_DEVICE static auto
106 const BDataType* b_flat_ptr,
107 const std::array<const void*, NumDTensor>& ds_ptr,
108 EDataType* e_ptr,
109 const KernelArgs& kargs,
110 const SplitKBatchOffset& splitk_batch_offset)
111 {
// A view: lengths (M, splitted_k) when ALayout is RowMajor, otherwise
// (splitted_k, M); in both cases the strides are (stride_A, 1).
112 const auto& a_tensor_view = [&]() {
113 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
114 {
116 a_ptr,
117 make_tuple(kargs.M, splitk_batch_offset.splitted_k),
118 make_tuple(kargs.stride_A, 1),
119 number<FlatmmPipeline::GetVectorSizeA()>{},
120 number<1>{});
121 }
122 else
123 {
125 a_ptr,
126 make_tuple(splitk_batch_offset.splitted_k, kargs.M),
127 make_tuple(kargs.stride_A, 1),
128 number<FlatmmPipeline::GetVectorSizeA()>{},
129 number<1>{});
130 }
131 }();
132
// Flat-B geometry: each flat row is K * WarpTile::at(I1) elements long, and the
// row count is N * K / kFlatK.
// NOTE(review): index_t is int32_t, so kargs.N * kargs.K can overflow for large
// problems before the division is applied - confirm the supported size range.
133 index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
134 index_t kFlatN = kargs.N * kargs.K / kFlatK;
135
// Flat-B view: lengths (kFlatN, kFlatK), strides (kFlatK, 1).
136 const auto& b_flat_tensor_view = [&]() {
138 b_flat_ptr,
139 make_tuple(kFlatN, kFlatK),
140 make_tuple(kFlatK, 1),
141 number<FlatmmPipeline::GetVectorSizeB()>{},
142 number<1>{});
143 }();
144
// One view per D tensor. Each untyped ds_ptr[i] is cast to the i-th element type
// of DsDataType; lengths are (M, N) for RowMajor D layouts and (N, M) otherwise,
// with strides (stride_Ds[i], 1).
145 const auto& ds_tensor_view = generate_tuple(
146 [&](auto i) {
147 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
148 using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
149 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
150 {
152 static_cast<const DDataType_*>(ds_ptr[i]),
153 make_tuple(kargs.M, kargs.N),
154 make_tuple(kargs.stride_Ds[i], 1),
155 number<EpiloguePipeline::GetVectorSizeD(i)>{},
156 number<1>{});
157 }
158 else
159 {
161 static_cast<const DDataType_*>(ds_ptr[i]),
162 make_tuple(kargs.N, kargs.M),
163 make_tuple(kargs.stride_Ds[i], 1),
164 number<EpiloguePipeline::GetVectorSizeD(i)>{},
165 number<1>{});
166 }
167 },
169
// E (output) view. The RowMajor branch uses EpiloguePipeline::GetVectorSizeC()
// as the vector-length argument; the other branch passes 1 instead.
170 // TODO: enable vector write for C in ColMajor
171 const auto& e_tensor_view = [&]() {
172 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
173 {
175 e_ptr,
176 make_tuple(kargs.M, kargs.N),
177 make_tuple(kargs.stride_E, 1),
178 number<EpiloguePipeline::GetVectorSizeC()>{},
179 number<1>{});
180 }
181 else
182 {
184 e_ptr,
185 make_tuple(kargs.N, kargs.M),
186 make_tuple(kargs.stride_E, 1),
187 number<1>{},
188 number<1>{});
189 }
190 }();
191
// B-scale view, built from kargs.scale_n_ptr and read as e8m0_t.
192 auto scale_n = kargs.scale_n_ptr;
193
// Flat scale geometry: (K / GranularityK) * N_Pack * WarpTile::at(I1) entries per
// row, and N / N_Pack / WarpTile::at(I1) rows.
194 index_t FlatScaleK =
195 (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
196 index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
197
// Global-memory view with lengths (FlatScaleN, FlatScaleK), strides
// (FlatScaleK, 1) and a guaranteed last-dimension vector length of 8.
198 const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
199 reinterpret_cast<const e8m0_t*>(scale_n.ptr),
200 make_tuple(FlatScaleN, FlatScaleK),
201 make_tuple(FlatScaleK, 1),
202 number<8>{},
203 number<1>{});
204
205 return make_tuple(
206 a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
207 }
208
// Takes the tuple of tensor views (A, flat-B, Ds, E, B-scale) and returns a tuple
// in the same order in which the A, D and E views have been passed through
// pad_tensor_view. The flat-B view (index 1) and the B-scale view (index 4) are
// forwarded unchanged.
//
// NOTE: the tile-length and padding arguments of every pad_tensor_view call are
// on lines that are missing from this listing.
209 template <typename TensorView>
210 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
211 {
// Pad A; the branch taken depends on whether ALayout is RowMajor.
212 const auto& a_pad_view = [&]() {
213 const auto& a_tensor_view = views.at(I0);
214 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
215 {
216 return pad_tensor_view(a_tensor_view,
220 }
221 else
222 {
223 return pad_tensor_view(a_tensor_view,
227 }
228 }();
229
// The flat-B view is not padded here.
230 const auto& b_flat_tensor_view = views.at(I1);
231
// Pad each D view individually, selecting the branch by that D's layout.
232 const auto& ds_pad_view = generate_tuple(
233 [&](auto i) {
234 const auto& d_tensor_view = views.at(I2);
235 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
236 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
237 {
238 return pad_tensor_view(d_tensor_view[i],
242 }
243 else
244 {
245 return pad_tensor_view(d_tensor_view[i],
249 }
250 },
252
// Pad E; the branch taken depends on whether ELayout is RowMajor.
253 // TODO: enable vector write for C in ColMajor
254 const auto& e_pad_view = [&]() {
255 const auto& e_tensor_view = views.at(I3);
256 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
257 {
258 return pad_tensor_view(e_tensor_view,
262 }
263 else
264 {
265 return pad_tensor_view(e_tensor_view,
269 }
270 }();
271
// Same ordering as the input tuple; index 4 (B-scale view) is passed through.
272 return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
273 }
274
// Creates the per-block tile windows over the padded views for the output tile
// whose origin is (i_m, i_n), and returns them as the tuple
//   (A window, flat-B window, tuple of D windows, E window, B-scale window).
//
// NOTE: the window-length arguments of most make_tile_window calls are on lines
// that are missing from this listing; only the origins are fully visible.
275 template <typename PadView>
276 CK_TILE_DEVICE static auto
277 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
278 {
279 const auto& a_pad_view = views.at(I0);
280 const auto& b_flat_pad_view = views.at(I1);
281 const auto& ds_pad_view = views.at(I2);
282 const auto& e_pad_view = views.at(I3);
283
// A window origin: {i_m, 0} for RowMajor A, {0, i_m} otherwise.
284 const auto& a_block_window = [&]() {
285 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
286 {
287 return make_tile_window(a_pad_view,
290 {i_m, 0});
291 }
292 else
293 {
294 return make_tile_window(a_pad_view,
297 {0, i_m});
298 }
299 }();
300
// Flat-B window origin: row i_n / WarpTile::at(I1), column 0. This matches the
// flat-B view whose row count was derived by dividing N by the same factor.
301 const auto& b_flat_block_window =
302 make_tile_window(b_flat_pad_view,
305 {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
306
// D window origins: {i_m, i_n} for RowMajor D, {i_n, i_m} otherwise.
307 const auto ds_block_window = generate_tuple(
308 [&](auto i) {
309 using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
310 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
311 {
312 return make_tile_window(ds_pad_view[i],
315 {i_m, i_n});
316 }
317 else
318 {
319 return make_tile_window(ds_pad_view[i],
322 {i_n, i_m});
323 }
324 },
326
// E window origin is always {i_m, i_n}; unlike A and D there is no layout branch.
327 auto e_block_window = make_tile_window(
328 e_pad_view,
330 {i_m, i_n});
331
// B-scale window: origin row i_n / WarpTile::at(I1) / N_Pack, column 0. The
// visible window length is flatKPerWarp * N_Pack * 4 / 32 (the other length is
// on a line missing from this listing).
332 auto scale_block_window =
333 make_tile_window(views.at(I4),
335 number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
336 {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
337
338 return make_tuple(a_block_window,
339 b_flat_block_window,
340 ds_block_window,
341 e_block_window,
342 scale_block_window);
343 }
344
// Computes one output tile: builds the views and tile windows, runs
// FlatmmPipeline to produce the accumulator tile, then runs EpiloguePipeline to
// write it out through the E window.
//
// Parameters:
//   a_ptr, b_flat_ptr, ds_ptr, e_ptr - input/output base pointers for this tile's
//                                      split-K slice (offsets applied by the caller).
//   smem_ptr_ping, smem_ptr_pong     - two scratch buffers handed to the pipeline;
//                                      the ping buffer is also given to the epilogue.
//   kargs                            - kernel arguments (sizes, strides, scale pointers).
//   splitk_batch_offset              - provides splitted_k, used for the loop count.
//   block_idx_m, block_idx_n         - origin of the output tile, forwarded to
//                                      MakeGemmTileWindows as i_m / i_n.
// Template parameter UseDefaultScheduler: when false and no epilogue scaling is
// requested, only the warp with get_warp_id() == 0 runs the epilogue.
345 template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
346 CK_TILE_DEVICE static void
347 RunFlatmm(const ADataType* a_ptr,
348 const BDataType* b_flat_ptr,
349 const std::array<const void*, NumDTensor>& ds_ptr,
350 EDataType* e_ptr,
351 void* smem_ptr_ping,
352 void* smem_ptr_pong,
353 const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
354 const SplitKBatchOffset& splitk_batch_offset,
355 const index_t block_idx_m,
356 const index_t block_idx_n)
357 {
358 // Create Gemm tensor views, pad views and tile windows
// (The call that produces gemm_tensor_views_tuple sits on a line missing from
// this listing; only its argument list is visible.)
359 const auto& gemm_tensor_views_tuple =
361 a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
362 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
363 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
364
// Number of main-loop iterations for this block's share of K.
365 const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
366
367 // Run GEMM cooperatively by whole workgroup.
368 const auto& a_block_window = gemm_tile_windows.at(I0);
369 const auto& b_flat_block_window = gemm_tile_windows.at(I1);
370 const auto& d_block_window = gemm_tile_windows.at(I2);
371 const auto& scale_block_window = gemm_tile_windows.at(I4);
372
// A GranularityMN of -1 marks the corresponding scale as unused; when both
// scales are in use they must share the same K granularity.
373 static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
374 || ScaleM::GranularityMN == -1 // or ScaleA is disabled
375 || ScaleN::GranularityMN == -1, // or ScaleB is disabled
376 "ScaleM and ScaleN should have the same GranularityK");
// Scaling is applied in the epilogue when an enabled scale has GranularityK == 0.
377 constexpr bool DoEpiScale =
378 (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
379 (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
380
// Rebuild the A window over the same view, lengths and origin, now carrying the
// pipeline's A DRAM tile distribution.
381 auto a_block_window_with_distr =
382 ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
383 a_block_window.get_window_lengths(),
384 a_block_window.get_window_origin(),
385 FlatmmPipeline::GetADramTileDistribution());
// Main GEMM loop; the result is the block's accumulator tile.
386 const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
387 b_flat_block_window,
388 scale_block_window,
389 num_loop,
390 smem_ptr_ping,
391 smem_ptr_pong);
392
393 // Run Epilogue Pipeline
394 if constexpr(DoEpiScale)
395 {
// Scaled epilogue: the scale pointers are advanced by the tile origin so the
// epilogue reads this tile's scales.
396 auto& c_block_window = gemm_tile_windows.at(I3);
397 EpiloguePipeline{}(c_block_window,
398 c_block_tile,
399 d_block_window,
400 smem_ptr_ping,
401 kargs.scale_m_ptr + block_idx_m,
402 kargs.scale_n_ptr + block_idx_n);
403 }
// Unscaled epilogue. This is a runtime condition: with the default scheduler all
// warps take part, otherwise only warp 0 does.
404 else if(UseDefaultScheduler || (get_warp_id() == 0))
405 {
406 // Run Epilogue Pipeline
407 auto& c_block_window = gemm_tile_windows.at(I3);
408 EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
409 }
410 }
411
// Device entry point. Processes the output tile identified by partition_idx
// (default: blockIdx.x). When UsePersistentKernel is set, it keeps going,
// advancing partition_idx by gridDim.x, until it reaches the total tile count.
//
// The body is a do-while, so the first tile is processed unconditionally;
// partition_idx is only compared against total_work_tile_cnt after each pass.
412 template <class ScaleM, class ScaleN>
413 CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
414 int partition_idx = blockIdx.x) const
415 {
416 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
417
418 do
419 {
// Map the linear tile index to (iM, iN) tile coordinates, then scale them to
// element offsets. The offsets go through __builtin_amdgcn_readfirstlane
// (presumably to keep them wave-uniform - confirm).
420 const auto [iM, iN] =
421 TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
422 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
423 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
424
425 const SplitKBatchOffset splitk_batch_offset(kargs);
426 // Typed base pointers for this block's split-K slice. The B offset is divided
// by QuantPackedSize (presumably because several B elements share one storage
// element - confirm against the QuantPackedSize definition).
427 const ADataType* a_ptr =
428 static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
429 const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
430 splitk_batch_offset.b_k_split_offset / QuantPackedSize;
431 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
432
433 // allocate LDS
// Two statically sized __shared__ buffers, sized by the base kernel's
// GetSmemPingSize() / GetSmemPongSize().
434 __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
435 __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
436
// The tile is skipped at compile time when the epilogue uses atomic_add, the C
// vector size is odd, and a further condition holds (that condition is on a
// line missing from this listing).
437 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
438 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
440 {
// True when the pipeline uses a single wave group. The call that consumes this
// flag and the arguments below is on a line missing from this listing
// (presumably RunFlatmm - confirm).
441 constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
443 b_flat_ptr,
444 kargs.ds_ptr,
445 e_ptr,
446 smem_ptr_ping,
447 smem_ptr_pong,
448 kargs,
449 splitk_batch_offset,
450 i_m,
451 i_n);
452 }
// Move on to the next tile assigned to this block.
453 partition_idx += gridDim.x;
454 } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
455 }
456};
457
458} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
@ atomic_add
Definition arch.hpp:58
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
e8m0_bexp_t e8m0_t
Definition tile/core/numeric/e8m0.hpp:49
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition mixed_prec_flatmm_kernel.hpp:18
static constexpr int N_Pack
Definition mixed_prec_flatmm_kernel.hpp:40
static constexpr auto I4
Definition mixed_prec_flatmm_kernel.hpp:48
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition mixed_prec_flatmm_kernel.hpp:26
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition mixed_prec_flatmm_kernel.hpp:29
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition mixed_prec_flatmm_kernel.hpp:22
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition mixed_prec_flatmm_kernel.hpp:23
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition mixed_prec_flatmm_kernel.hpp:30
static constexpr index_t KernelBlockSize
Definition mixed_prec_flatmm_kernel.hpp:31
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition mixed_prec_flatmm_kernel.hpp:28
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition mixed_prec_flatmm_kernel.hpp:37
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition mixed_prec_flatmm_kernel.hpp:277
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition mixed_prec_flatmm_kernel.hpp:21
static CK_TILE_HOST const std::string GetName()
Definition mixed_prec_flatmm_kernel.hpp:54
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition mixed_prec_flatmm_kernel.hpp:35
static constexpr auto I0
Definition mixed_prec_flatmm_kernel.hpp:44
static constexpr auto I1
Definition mixed_prec_flatmm_kernel.hpp:45
static constexpr auto I2
Definition mixed_prec_flatmm_kernel.hpp:46
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition mixed_prec_flatmm_kernel.hpp:413
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition mixed_prec_flatmm_kernel.hpp:25
static constexpr int QuantPackedSize
Definition mixed_prec_flatmm_kernel.hpp:39
static constexpr bool UsePersistentKernel
Definition mixed_prec_flatmm_kernel.hpp:32
static CK_TILE_HOST constexpr auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition mixed_prec_flatmm_kernel.hpp:63
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition mixed_prec_flatmm_kernel.hpp:101
FlatmmKernel< TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_ > Underlying
Definition mixed_prec_flatmm_kernel.hpp:19
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition mixed_prec_flatmm_kernel.hpp:34
static constexpr auto I3
Definition mixed_prec_flatmm_kernel.hpp:47
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition mixed_prec_flatmm_kernel.hpp:347
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition mixed_prec_flatmm_kernel.hpp:105
static constexpr index_t NumDTensor
Definition mixed_prec_flatmm_kernel.hpp:42
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition mixed_prec_flatmm_kernel.hpp:27
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition mixed_prec_flatmm_kernel.hpp:210
Definition flatmm_kernel.hpp:362
Definition flatmm_kernel.hpp:229
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize()
Definition flatmm_kernel.hpp:356
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize()
Definition flatmm_kernel.hpp:352
Definition type_traits.hpp:115
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49