block_wp_asmem_bsmem_creg_v1.hpp Source File

block_wp_asmem_bsmem_creg_v1.hpp Source File

Composable Kernel: block_wp_asmem_bsmem_creg_v1.hpp Source File
block_wp_asmem_bsmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block window on shared memory
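12// B is passed to operator() as per-warp tensors indexed [nIter][kIter] and fed directly to the warp GEMM (no load_tile is performed on it)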
13// C is block distributed tensor
14template <typename Problem_, typename BlockPolicy_>
16{
23
24 static constexpr auto I0 = number<0>();
25 static constexpr auto I1 = number<1>();
26 static constexpr auto I2 = number<2>();
27 static constexpr auto idxM = I0;
28 static constexpr auto idxN = I1;
29 static constexpr auto idxK = I2;
33
34 static constexpr index_t kBlockSize = Problem::kBlockSize;
35
36 static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
37 using WG = remove_cvref_t<decltype(config.template at<0>())>;
38
39 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
40 {
41 constexpr index_t MPerBlock = BlockGemmShape::kM;
42 constexpr index_t NPerBlock = BlockGemmShape::kN;
43
44 constexpr index_t MWarp = config.template at<1>();
45 constexpr index_t NWarp = config.template at<2>();
46
47 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
48 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
49
50 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
57
58 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
59 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
60
61 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
62
63 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
64 return c_block_tensor;
65 }
66
67 // C += A * B
68 template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
69 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
70 ABlockWindow& a_warp_windows,
71 BFlatBlockTensor& b_warp_tensor) const
72 {
73 constexpr index_t MPerBlock = BlockGemmShape::kM;
74 constexpr index_t KPerBlock = BlockGemmShape::kK;
75
76 constexpr index_t MWarp = config.template at<1>();
77
78 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
79 constexpr index_t NIterPerWarp =
80 BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
81 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
82
83 using CWarpDstr = typename WG::CWarpDstr;
84 using CWarpTensor = typename WG::CWarpTensor;
85
86 constexpr auto c_warp_y_lengths =
87 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
88 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
89
90 // hot loop:
91 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
92 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
93 // read A warp tensor from A block window
94 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
95
96 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
97 // read C warp tensor from C block tensor
98 CWarpTensor c_warp_tensor;
99
100 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
101 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
102 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
103
104 // warp GEMM
105 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
106
107 // write C warp tensor into C block tensor
108 c_block_tensor.set_y_sliced_thread_data(
109 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
110 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
111 c_warp_tensor.get_thread_buffer());
112 });
113 });
114 });
115 }
116};
117
118} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
Definition block_wp_asmem_bsmem_creg_v1.hpp:16
static constexpr auto I0
Definition block_wp_asmem_bsmem_creg_v1.hpp:24
static constexpr auto config
Definition block_wp_asmem_bsmem_creg_v1.hpp:36
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_wp_asmem_bsmem_creg_v1.hpp:39
static constexpr index_t kBlockSize
Definition block_wp_asmem_bsmem_creg_v1.hpp:34
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_wp_asmem_bsmem_creg_v1.hpp:22
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_wp_asmem_bsmem_creg_v1.hpp:19
static constexpr auto idxK
Definition block_wp_asmem_bsmem_creg_v1.hpp:29
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_wp_asmem_bsmem_creg_v1.hpp:21
static constexpr auto idxM
Definition block_wp_asmem_bsmem_creg_v1.hpp:27
remove_cvref_t< Problem_ > Problem
Definition block_wp_asmem_bsmem_creg_v1.hpp:17
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition block_wp_asmem_bsmem_creg_v1.hpp:37
static constexpr auto I2
Definition block_wp_asmem_bsmem_creg_v1.hpp:26
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition block_wp_asmem_bsmem_creg_v1.hpp:32
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_wp_asmem_bsmem_creg_v1.hpp:20
static constexpr auto idxN
Definition block_wp_asmem_bsmem_creg_v1.hpp:28
remove_cvref_t< BlockPolicy_ > BlockPolicy
Definition block_wp_asmem_bsmem_creg_v1.hpp:18
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, ABlockWindow &a_warp_windows, BFlatBlockTensor &b_warp_tensor) const
Definition block_wp_asmem_bsmem_creg_v1.hpp:69
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition block_wp_asmem_bsmem_creg_v1.hpp:31
static constexpr auto I1
Definition block_wp_asmem_bsmem_creg_v1.hpp:25
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition block_wp_asmem_bsmem_creg_v1.hpp:30
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192