22 template <
typename Problem>
25 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::QDataType);
28 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
31 return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
34 template <
typename Problem>
39#if defined(__gfx950__)
40 constexpr index_t MaxReadSizeInBytes = 16;
42 constexpr index_t MaxReadSizeInBytes = 4;
44 return MaxReadSizeInBytes /
sizeof(KDataType);
47 template <
typename Problem>
52#if defined(__gfx950__)
53 constexpr index_t MaxReadSizeInBytes = 16;
55 constexpr index_t MaxReadSizeInBytes = 4;
57 return MaxReadSizeInBytes /
sizeof(VDataType);
60 template <
typename Problem>
64 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
67 return WG::WarpGemmAttribute::Impl::kCM1PerLane;
70 template <
typename Problem>
77 return 16 /
sizeof(KDataType);
80 template <
typename Problem>
87 return 16 /
sizeof(VDataType);
90 template <
typename Problem>
95 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
96 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
97 constexpr index_t kBlockSize = Problem::kBlockSize;
98 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
103 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
104 constexpr index_t LanesPerK = kKPerBlock / KVector;
105 constexpr index_t LaneGroups = WarpSize / LanesPerK;
106 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
107 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
109 constexpr index_t N0 = NumIssues;
110 constexpr index_t N1 = LaneGroups;
111 constexpr index_t N2 = NumWarps;
112 constexpr index_t K0 = LanesPerK;
113 constexpr index_t K1 = KVector;
124 template <
typename Problem>
129 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
130 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
131 constexpr index_t kBlockSize = Problem::kBlockSize;
132 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
137 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
138 constexpr index_t LanesPerK = kKPerBlock / KVector;
139 constexpr index_t LaneGroups = WarpSize / LanesPerK;
140 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
141 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
143 constexpr index_t N0 = NumIssues;
144 constexpr index_t N1 = LaneGroups;
145 constexpr index_t N2 = NumWarps;
146 constexpr index_t K0 = LanesPerK;
147 constexpr index_t K1 = KVector;
158 template <
typename Problem>
168 template <
typename Problem>
178 template <
typename Problem>
188 template <
typename Problem>
194 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
195 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
197 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
198 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
200 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
201 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
203 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
204 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
206 constexpr auto v_block_outer_dstr_encoding =
215 v_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
218 constexpr auto v_block_dstr =
220 decltype(v_block_dstr_encode),
221 typename Problem::VDataType>::TransposedDstrEncode{});
226 template <
typename Problem>
233 typename Problem::KDataType,
234 typename Problem::SaccDataType,
237 Problem::BlockFmhaShape::kN0,
238 Problem::BlockFmhaShape::kK0>,
239 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
240 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
242 constexpr auto warp_gemm = []() {
243 if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
244 std::is_same_v<typename Problem::KDataType, half_t> &&
245 std::is_same_v<typename Problem::SaccDataType, float>)
251 else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
252 std::is_same_v<typename Problem::KDataType, bf16_t> &&
253 std::is_same_v<typename Problem::SaccDataType, float>)
261 using BlockGemmPolicy =
263 typename Problem::KDataType,
264 typename Problem::SaccDataType,
265 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
272 template <
typename Problem>
279 typename Problem::VDataType,
280 typename Problem::OaccDataType,
283 Problem::BlockFmhaShape::kN1,
284 Problem::BlockFmhaShape::kK1>,
285 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
286 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
290 typename Problem::VDataType,
291 typename Problem::OaccDataType,
292 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}),
293 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}),
294 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}),
300 using BlockGemmPolicy =
302 typename Problem::VDataType,
303 typename Problem::OaccDataType,
304 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
313 template <
typename Problem, ck_tile::index_t IBuf = 0>
320 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
321 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
322 constexpr index_t kBlockSize = Problem::kBlockSize;
323 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
330 sizeof(
typename Problem::KDataType);
333 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
335 kKPerBlock / KVector;
339 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
340 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
365 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
366 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
368 return k_lds_block_desc_issues_warps_lanes;
371 template <
typename Problem>
377 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
378 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
379 constexpr index_t kBlockSize = Problem::kBlockSize;
380 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
387 sizeof(
typename Problem::KDataType);
389 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
390 constexpr index_t LanesPerK = kKPerBlock / KVector;
391 constexpr index_t LaneGroups = WarpSize / LanesPerK;
392 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
393 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
395 constexpr auto k_lds_block_desc_0 =
399 number<kKPerBlock / KPack>{},
418 return k_lds_block_desc;
421 template <
typename Problem>
425 constexpr index_t SingleKSize = [&]() {
426 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
427 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
428 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
433 constexpr index_t kPad = KPack;
435 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
436 constexpr index_t LanesPerK = kKPerBlock / KVector;
437 constexpr index_t LaneGroups = WarpSize / LanesPerK;
438 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
440 return NumIssues * NumWarps * (WarpSize * KVector + kPad);
443 constexpr index_t SingleVSize = [&]() {
445 constexpr index_t Banks = get_n_lds_banks();
446 constexpr index_t PixelsPerRow = Banks * 4 /
sizeof(VDataType);
448 static_assert(PixelsPerRow % kKPack == 0);
449 constexpr index_t NPerRow = PixelsPerRow / kKPack;
450 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
451 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
452 static_assert(kNPerBlock % NPerRow == 0);
453 static_assert(kKPerBlock % kKPack == 0);
455 return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
458 return max(SingleKSize, SingleVSize);
461 template <
typename Problem, ck_tile::index_t IBuf = 0>
468 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
469 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
470 constexpr index_t kBlockSize = Problem::kBlockSize;
471 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
478 sizeof(
typename Problem::VDataType);
481 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
483 kKPerBlock / KVector;
487 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
488 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
513 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
514 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
516 return v_lds_block_desc_issues_warps_lanes;
519 template <
typename Problem>
525 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
526 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
527 constexpr index_t kBlockSize = Problem::kBlockSize;
528 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
535 sizeof(
typename Problem::VDataType);
537 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
538 constexpr index_t LanesPerK = kKPerBlock / KVector;
539 constexpr index_t LaneGroups = WarpSize / LanesPerK;
540 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
541 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
543 constexpr auto v_lds_block_desc_0 =
547 number<kKPerBlock / KPack>{},
566 return v_lds_block_desc;
569 template <
typename Problem>
576 constexpr index_t k_element_space_size =
581 constexpr index_t v_element_space_size =
584 static_assert(
ck_tile::max(k_element_space_size, v_element_space_size) <=
589 static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
590 constexpr index_t kv_element_space_size_in_bytes =
593 return kv_element_space_size_in_bytes;
596 template <
typename Problem>
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_with_offset(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, const offset &os, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:319
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
Definition warp_gemm.hpp:91
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
Definition warp_gemm.hpp:213
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:15
static CK_TILE_DEVICE constexpr auto MakePRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:179
static CK_TILE_DEVICE constexpr auto MakeKLdsStoreBlockDescriptor(ck_tile::number< IBuf >=ck_tile::number< 0 >{})
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:315
static CK_TILE_DEVICE constexpr auto MakeVRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:189
static constexpr ck_tile::index_t kKLdsPadInBytes
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:310
static CK_TILE_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:159
static constexpr ck_tile::index_t NumThreadPerWarpGroup
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:17
static CK_TILE_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:597
static constexpr ck_tile::index_t NumWarpPerGroup
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:16
static CK_TILE_DEVICE constexpr auto MakeKLdsLoadBlockDescriptor()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:372
static CK_TILE_DEVICE constexpr ck_tile::index_t GetSmemSizeKV()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:570
static CK_TILE_HOST_DEVICE constexpr auto GetSmemVPackK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:81
static CK_TILE_DEVICE constexpr auto GetSingleSmemElementSpaceSize()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:422
static CK_TILE_DEVICE constexpr auto MakeKRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:169
static CK_TILE_DEVICE constexpr auto MakeVLdsLoadBlockDescriptor()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:520
static CK_TILE_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:35
static CK_TILE_DEVICE constexpr auto GetPVBlockGemm()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:273
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:71
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:61
static CK_TILE_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:91
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:125
static CK_TILE_DEVICE constexpr auto MakeVLdsStoreBlockDescriptor(ck_tile::number< IBuf >=ck_tile::number< 0 >{})
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:463
static constexpr ck_tile::index_t kVLdsPadInBytes
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:311
static CK_TILE_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:227
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:23
static CK_TILE_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:48
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:23
Definition block_gemm_areg_breg_creg_v2.hpp:17
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192