gridwise_sparse_embeddings_forward_layernorm_builtins.hpp Source File

gridwise_sparse_embeddings_forward_layernorm_builtins.hpp Source File

Composable Kernel: gridwise_sparse_embeddings_forward_layernorm_builtins.hpp Source File
gridwise_sparse_embeddings_forward_layernorm_builtins.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
// kernel_sparse_embeddings_forward_layernorm -- the __global__ declarator line (original
// line 26) is absent from this listing; the name and full signature are given in the symbol
// index at the end of this page.
//
// Thin kernel entry point: it forwards every argument unchanged to
// GridwiseSparseEmbedding::Run and does no work of its own, so thread/block mapping and all
// memory addressing are decided entirely by the GridwiseSparseEmbedding type.
//
// Parameters (all forwarded as-is to GridwiseSparseEmbedding::Run):
//   p_out              - output buffer written by Run.
//   p_embs             - NumEmbeddings embedding-table pointers.
//   p_indexes          - NumEmbeddings index-array pointers (one per embedding table).
//   p_gamma, p_beta    - presumably the layernorm scale / shift vectors -- confirm.
//   out_grid_desc      - output grid descriptor.
//   epsilon            - presumably the layernorm variance epsilon -- confirm.
//   emb_elementwise_op - elementwise operation handed through to Run.
//
// NOTE(review): launch bounds are only applied when CK_USE_LAUNCH_BOUNDS is non-zero; the
// attribute line itself (original line 24) is missing here -- per the symbol index it
// presumably uses CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU; confirm against the header.
13template <typename GridwiseSparseEmbedding,
14 typename EmbType,
15 typename IndexType,
16 typename GammaDataType,
17 typename BetaDataType,
18 typename AccDataType,
19 typename OutType,
20 typename OutGridDesc,
21 typename EmbElementwiseOperation,
22 ck::index_t NumEmbeddings>
23#if CK_USE_LAUNCH_BOUNDS
25#endif
27 OutType* p_out,
28 const ck::Array<EmbType*, NumEmbeddings> p_embs,
29 const ck::Array<IndexType*, NumEmbeddings> p_indexes,
30 const GammaDataType* p_gamma,
31 const BetaDataType* p_beta,
32 const OutGridDesc out_grid_desc,
33 const AccDataType epsilon,
34 const EmbElementwiseOperation emb_elementwise_op)
35{
// Pure pass-through: no per-kernel setup happens before Run.
36 GridwiseSparseEmbedding::Run(
37 p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
38}
39
// GridwiseSparseEmbeddingsForwardLayernorm -- AMDGPU buffer-resource-builtin variant. The
// struct declarator (original line 56) is absent from this listing.
//
// Fused forward pass. For every output row it:
//   1. gathers one row from each of the NumEmbeddings embedding tables through an index array;
//   2. combines the gathered rows with EmbElementwiseOperation;
//   3. applies layernorm along the row:
//        y = (x - mean) / sqrt(var + epsilon) * gamma + beta
//
// Tiling (the template comments call the tile "Row x Dim"):
//   - a "Dim" position selects an output row (one looked-up index per Dim position);
//   - "Row" runs along the elements inside a row;
//   - the block is a DimClusterSize x RowClusterSize thread cluster, and BlockSize must
//     equal the product;
//   - each thread keeps DimThreadSize rows in registers and touches RowVectorSize
//     contiguous elements per vector load/store;
//   - DimSubBlocks / RowSubBlocks are the number of register passes needed to cover
//     DimPerBlock / RowPerBlock.
//
// Addressing assumptions visible in Run:
//   - embedding rows and output rows both use a row stride of RowPerBlock elements;
//   - the OutGridDesc argument is unnamed and is not used for addressing;
//   - WaveSize is hard-coded to 64 (wave64), and RowClusterSize must be a multiple of it.
40template <typename EmbType,
41 typename IndexType,
42 typename GammaDataType,
43 typename BetaDataType,
44 typename AccDataType,
45 typename OutType,
46 typename OutGridDesc,
47 typename EmbElementwiseOperation,
48 ck::index_t BlockSize,
49 ck::index_t DimClusterSize,
50 ck::index_t RowClusterSize,
51 ck::index_t DimPerBlock, // Row x Dim, along Dim
52 ck::index_t RowPerBlock, // Row x Dim, along Row
53 ck::index_t DimThreadSize, // this is actually not vector, but number of registers
54 ck::index_t RowVectorSize,
55 ck::index_t NumEmbeddings>
57{
58 static constexpr auto I0 = Number<0>{};
59 static constexpr auto I1 = Number<1>{};
60 static constexpr auto I2 = Number<2>{};
61 static constexpr auto I3 = Number<3>{};
62 static constexpr index_t WaveSize = 64;
63
// Compile-time tiling checks: the cluster must exactly fill the block, whole waves must
// fit along Row, and the per-block tile must be an integer number of register passes.
64 static_assert(BlockSize == RowClusterSize * DimClusterSize,
65 "Invalid cluster distribution within block");
66 static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");
67
68 static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
69 static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
70
// Number of register passes along each axis.
71 static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
72 static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
73
// Extent covered by one pass: DimPerSubBlock == DimClusterSize * DimThreadSize and
// RowPerSubBlock == RowClusterSize * RowVectorSize, given the asserts above.
74 static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
75 static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
76 static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
77
80
83
86
88
91
// Run: per-block body. The p_embs / p_indexes parameters (original lines 93-94) are
// absent from this listing; see the symbol index for the full signature. The OutGridDesc
// parameter is intentionally unnamed, so output addressing below relies only on RowPerBlock.
92 __device__ static void Run(OutType* p_out,
95 const GammaDataType* p_gamma,
96 const BetaDataType* p_beta,
97 const OutGridDesc,
98 const AccDataType epsilon,
99 const EmbElementwiseOperation emb_elementwise_op)
100 {
101 const index_t thread_local_id = get_thread_local_1d_id();
102 const index_t block_global_id = get_block_1d_id();
103
// Split the 1-D thread id into (dim, row) cluster coordinates (descriptor construction,
// original line 105, is absent from this listing).
104 constexpr auto thread_cluster_desc =
106
107 const auto thread_cluster_idx =
108 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
109
110 const auto thread_dim_cluster_id = thread_cluster_idx[I0];
111 const auto thread_row_cluster_id = thread_cluster_idx[I1];
112
// readfirstlane makes wave_dim_id (and hence index_start) wave-uniform, so the index loads
// below can be scalar.
// NOTE(review): thread_dim_cluster_id is a Dim cluster coordinate (presumably
// < DimClusterSize). Dividing it by WaveSize yields 0 unless DimClusterSize > WaveSize.
// Confirm whether the coordinate itself was intended.
113 const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
114
// First output row handled by this wave; also the base of the index prefetch and of the
// output buffer resource.
115 const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
116
// Per-thread Welford state. max_count_ is the number of row elements this thread sees per
// output row (RowSubBlocks passes x RowVectorSize elements).
117 auto threadwise_welford = ThreadwiseWelford();
118 threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
119
// Register-buffer layouts. Several declarations (original lines 126, 129, 131, 134, 137,
// 139, 141, 144-145) are absent from this listing.
//   - thread_buf_desc: packed (DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize),
//     shared by in_thread_bufs and acc_thread_buf;
//   - mean_var_buf_desc: indexed by (dim sub-block, register row), one slot per output row;
//   - gamma_beta_buf_desc: indexed by (row sub-block, vector element).
120 constexpr auto thread_buf_size =
121 DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
122 constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
123 make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
124 constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
125 constexpr auto mean_var_buf_desc =
127 constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
128 constexpr auto gamma_beta_buf_desc =
130
132 NumEmbeddings>
133 in_thread_bufs;
135 index_bufs;
136
138
140 gamma_thread_buf;
142 beta_thread_buf;
143
146
// Gather row sub-block i_row_sub_ of dim sub-block i_dim_sub_.
// For each register row and each embedding table it:
//   - picks the prefetched index for that row;
//   - builds a buffer resource at that table row (row stride RowPerBlock);
//   - vector-loads RowVectorSize elements at this thread's row offset;
//   - scatters them element-wise into in_thread_bufs. The conversion line (original 171)
//     is absent here and is presumably type_convert<AccDataType>.
// NOTE(review): `index * RowPerBlock` is evaluated in IndexType/index_t arithmetic and may
// overflow for large tables if those types are 32-bit. Confirm the expected table sizes.
147 auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
149 auto emb_a = emb_vectors[0];
150 using src_vector_t = typename decltype(emb_a)::type;
151 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
152 constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
153
// Byte offset of this thread's vector inside the row for this row sub-block.
154 auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
155 sizeof(EmbType) * RowVectorSize;
156 static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
157 IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
158
159 __amdgpu_buffer_rsrc_t emb_res =
161 index * RowPerBlock);
162 emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
163 amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
164 });
165
166 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
167 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
168 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
169 static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
170 in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
172 emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
173 });
174 });
175 });
176 };
177
// Combine the NumEmbeddings gathered values for each element of this sub-row.
// The call is emb_elementwise_op(acc, in_0, ..., in_{NumEmbeddings-1}): the single output
// reference is the acc_thread_buf slot at the same register offset.
// (The generate_tie count argument, original line 187, is absent from this listing.)
178 auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
179 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
180 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
181 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
182 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
183 auto in_data_refs = generate_tie(
184 [&](auto i_embedding_) -> const auto& {
185 return in_thread_bufs(i_embedding_)(Number<register_offset>{});
186 },
188 auto out_data_refs = generate_tie(
189 [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
190 Number<1>{});
191 unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
192 });
193 });
194 };
195
// Fold every combined element of this sub-row into the per-row Welford mean/var slots.
// NOTE(review): a single cur_count_ is incremented for every element of every register row,
// and no reset between dim sub-blocks is visible here. Each mean/var slot only receives
// RowVectorSize elements per row sub-block, so with DimThreadSize > 1 (or DimSubBlocks > 1)
// the count seen by Update appears to exceed that slot's sample count. Confirm against
// ThreadwiseWelford::Update.
196 auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
197 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
198 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
199 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
200 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
201 constexpr auto mean_var_offset =
202 mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
203
204 threadwise_welford.cur_count_++;
205 threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
206 var_thread_buf(Number<mean_var_offset>{}),
207 acc_thread_buf(Number<register_offset>{}));
208 });
209 });
210 };
211
// Normalize one sub-row and store it. For each register row:
//   - y = (x - mean) * (1 / sqrt(var + epsilon)) * gamma + beta, computed in AccDataType;
//   - the RowVectorSize results are converted to OutType and packed into one vector;
//   - one buffer store writes that vector. The store call line (original 241) and the
//     out_vector declaration (original 216) are absent from this listing.
// __builtin_amdgcn_sqrtf is the f32 square root, so this presumably assumes AccDataType is
// float -- confirm.
// NOTE(review): the store address depends only on index_start, thread_row_cluster_id and
// i_row_sub_, not on i_dim_vec_, i_dim_sub_ or thread_dim_cluster_id. With more than one
// register row or dim sub-block, every row seems to be written to the same output row.
// Confirm.
212 auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
213 __amdgpu_buffer_rsrc_t out_res =
214 make_wave_buffer_resource_with_default_range_new(p_out + index_start * RowPerBlock);
215 static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
217 using dst_vector_t = typename decltype(out_vector)::type;
218
219 constexpr auto mean_var_offset =
220 mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
// Reciprocal standard deviation is computed once per row and reused for every element.
221 auto divisor =
222 1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
223 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
224 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
225 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
226 constexpr auto gamma_beta_offset =
227 gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
228
229 auto acc_val = acc_thread_buf[Number<register_offset>{}];
230 acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
231 acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
232 beta_thread_buf[Number<gamma_beta_offset>{}];
233
234 out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
235 type_convert<OutType>(acc_val);
236 });
237
// Same per-thread row offset scheme as the embedding loads, in OutType bytes.
238 index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
239 sizeof(OutType) * RowVectorSize;
240
242 out_vector.template AsType<dst_vector_t>()[Number<0>{}],
243 out_res,
244 thread_offset,
245 0);
246 });
247 };
248
249 // first load index
// Prefetch DimPerBlock consecutive indices per embedding table, starting at index_start.
// The addresses depend only on index_start (made wave-uniform via readfirstlane above) and
// compile-time constants, which is what lets the compiler use scalar (s_load) loads.
// NOTE(review): index_start already includes wave_dim_id * DimThreadSize, so a non-zero
// wave_dim_id would read past this block's DimPerBlock range of indices. Confirm.
250 ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
251 // prefer use s_load
252 ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
253 index_bufs(i_embedding_)(i_idx_) =
254 p_indexes[i_embedding_][index_start + i_idx_.value];
255 });
256 });
257
258 // load gamma/beta
// Per row sub-block, vector-load this thread's RowVectorSize gamma/beta elements at the
// same row positions the embedding loads use, then convert each element to AccDataType.
// (Vector/resource declaration lines 260-261, 269, 271 and 274 are absent from this
// listing.)
259 static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
262
263 index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
264 sizeof(GammaDataType) * RowVectorSize;
265 index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
266 sizeof(BetaDataType) * RowVectorSize;
267
268 __amdgpu_buffer_rsrc_t gamma_res =
270 __amdgpu_buffer_rsrc_t beta_res =
272
273 gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
275 gamma_res, thread_offset_gamma, 0);
276 beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
277 amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
278
279 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
280 constexpr auto offset =
281 gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
282 gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
283 gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
284 beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
285 beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
286 });
287 });
288
// Zero the accumulator and every mean/var slot once, before any dim sub-block is processed
// (loop header lines 289 and 292 are absent from this listing).
290 [&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
291
293 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
294 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
295 });
296
// Main loop over dim sub-blocks. Row sub-blocks are processed one pass ahead: sub-row k+1
// is loaded before sub-row k is combined and folded into the Welford statistics, apparently
// so that load latency overlaps compute. The last sub-row is finished after the inner loop,
// which is empty when RowSubBlocks == 1.
297 static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
298 load_current_sub_row(i_dim_sub, Number<0>{});
299 static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
300 load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
301 accumulate_current_sub_row(i_dim_sub, i_row);
302 threadwise_welford_sub_row(i_dim_sub, i_row);
303 });
304 accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
305 threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
306
// Combine per-thread mean/var across the block. The loop header and parts of the body
// (original lines 308, 310-311) are absent; per the symbol index this presumably calls
// BlockwiseWelford::Run, with a block_sync_lds between slots, guarded by I > 0. That Run
// takes cur_count_ by non-const reference, so it may update it.
307 // blockwise welford
309 if constexpr(I > 0)
312 mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
313 });
314
// Normalize and write every row sub-block of this dim sub-block (loop header, original
// line 316, is absent from this listing).
315 // store
317 [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
318 });
319 }
320};
321
322} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__device__ void amd_buffer_store_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:544
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:26
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ vector_type< T, N >::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:419
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T *p_wave)
Definition utility/amd_buffer_addressing_builtins.hpp:66
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition utility/array.hpp:14
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
static constexpr auto I0
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:58
static constexpr auto RowPerSubBlock
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:76
static constexpr auto RowSubBlocks
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:72
static constexpr auto DimSubBlocks
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:71
static constexpr auto DimPerSubBlock
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:75
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< DimSubBlocks *DimThreadSize >{}, Number< RowSubBlocks *RowVectorSize >{}))) ThreadwiseWolfordDesc2D
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:78
static constexpr auto I1
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:59
static constexpr auto I2
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:60
static __device__ void Run(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm_builtins.hpp:92
static constexpr index_t WaveSize
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:62
static constexpr auto I3
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:61
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition functional2.hpp:33