Composable Kernel: thread_group_tensor_slice_transfer_global.hpp Source File

Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15template <typename SrcDesc,
16 typename DstDesc,
17 typename SrcData,
18 typename DstData,
19 typename ElementwiseOperation,
20 typename NumberOfIterations,
21 typename StepsPerIteration,
22 typename IterationOrder,
23 index_t VectorSize,
24 bool DoTranspose>
26{
27 static constexpr auto I0 = Number<0>{};
28 static constexpr auto I1 = Number<1>{};
29 static constexpr auto I2 = Number<2>{};
30 static constexpr auto I3 = Number<3>{};
31 static constexpr auto I4 = Number<4>{};
32 static constexpr auto I5 = Number<5>{};
33 static constexpr auto I6 = Number<6>{};
34
37 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
38 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
39
40 __device__ ThreadGroupTransferGlobal(const SrcDesc& src_desc,
41 const DstDesc& dst_desc,
42 const Index& src_block_slice_origin,
43 const Index& dst_block_slice_origin,
44 const ElementwiseOperation& element_op)
45 : src_coord_(make_tensor_coordinate(src_desc, src_block_slice_origin)),
46 dst_coord_(make_tensor_coordinate(dst_desc, dst_block_slice_origin)),
47 element_op_(element_op)
48 {
49 }
50
51 template <typename GridBufferType>
52 __device__ void RunRead(const SrcDesc& src_desc, const GridBufferType& grid_buf)
53 {
54 constexpr auto src_access_lengths = NumberOfIterations{};
55 constexpr auto src_dim_access_order = IterationOrder{};
56 constexpr auto ordered_src_access_lengths =
57 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
58 constexpr auto ordered_fwd_step = StepsPerIteration{};
59
60 // make forward steps
61 // forward step for each iteration just add 1
62 const auto src_forward_steps = generate_tuple(
63 [&](auto i) {
64 Index forward_step_idx;
65
66 static_for<0, nDim, 1>{}([&](auto j) {
67 forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0;
68 });
69
70 return make_tensor_coordinate_step(src_desc, forward_step_idx);
71 },
72 Number<nDim>{});
73
74 // make backward steps
75 // backward step at the end of the dimension iteration subtract IterationLength - 1
76 const auto src_backward_steps = generate_tuple(
77 [&](auto i) {
78 Index backward_step_idx;
79
80 static_for<0, nDim, 1>{}([&](auto j) {
81 backward_step_idx(j) = (i.value == j.value)
82 ? (-src_access_lengths[i] + 1) * ordered_fwd_step[i]
83 : 0;
84 });
85
86 return make_tensor_coordinate_step(src_desc, backward_step_idx);
87 },
88 Number<nDim>{});
89
90 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
91 // judge move forward or move backward
92 constexpr auto forward_sweep = [&]() {
94
95 // Take condition for bwd and negate
96 // condition for bwd: dimension index is the last of iteration and
97 // all dimension indices of higher dimensions (inner loops)
98 // are the last of their iteration
99 static_for<0, nDim, 1>{}([&](auto i) {
100 bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1;
101 static_for<i + 1, nDim, 1>{}([&](auto j) {
102 tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
103 });
104 forward_sweep_(i) = !tmp;
105 });
106 return forward_sweep_;
107 }();
108
109 // check for each dimension, if it needs to be moved (either fwd or bwd)
110 constexpr auto move_on_dim = [&]() constexpr {
112
113 // forward condition
114 static_for<0, nDim, 1>{}([&](auto i) {
115 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
116
117 static_for<i + 1, nDim, 1>{}([&](auto j) {
118 move_on_dim_(i) &=
119 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
120 });
121 });
122
123 // backward condition
124 static_for<0, nDim, 1>{}([&](auto i) {
125 bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1 &&
126 ordered_src_access_idx[i] > 0;
127 static_for<i + 1, nDim, 1>{}([&](auto j) {
128 tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
129 });
130 move_on_dim_(i) |= tmp;
131 });
132
133 return move_on_dim_;
134 }();
135
136 // calculate src data index and make sequence
137 constexpr auto src_data_idx = [&]() {
138 Index ordered_idx;
139
141 [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
142
143 return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
144 }();
145
146 // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
147 constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
148 [&](auto i) {
149 if constexpr(i.value < src_data_idx.Size())
150 {
152 }
153 else
154 {
155 return Number<0>{};
156 }
157 },
158 Number<src_data_idx.Size() + 1>{});
159
160 // check if src element is valid
161 const bool is_src_valid =
163
164 // Vector length of elementwise operation
165 constexpr auto get_elem_op_vec_len = []() {
166 if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
167 {
168 if constexpr(decltype(element_op_)::is_pack8_invocable)
169 return math::min(8, VectorSize);
170 }
171 else if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
172 {
173 if constexpr(decltype(element_op_)::is_pack4_invocable)
174 return math::min(4, VectorSize);
175 }
176 else if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
177 {
178 if constexpr(decltype(element_op_)::is_pack2_invocable)
179 return math::min(2, VectorSize);
180 }
181 else
182 {
183 return 1;
184 }
185 };
186
187 // This is 1 for pass through because internally it's doing type conversion
188 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
189
190 using src_vector_container = vector_type_maker_t<SrcData, VectorSize>;
191 using src_vector_container_t = typename src_vector_container::type;
192
193 using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
194
195 using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
196 using dst_vector_t = typename dst_vector_type::type;
197
199
200 dst_vector_type op_r_v;
201
202 // Load data from memory in src_vector first
203 src_vector_container src_vector =
204 src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
205 src_coord_.GetOffset(), true)};
206
207 // apply the src elementwise op and convert to DstData under the hood if needed
208 static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
209 element_op_(op_r_v.template AsType<elem_op_vec_t>()(idx),
210 src_vector.template AsType<elem_op_vec_t>()[idx]);
211 });
212
213 // store result in dvgpr_ (static array holding loaded data).
214 // At this point data is already converted to DstData type and
215 // the elementwise operation has been applied
216 dvgpr_.template SetAsType<dst_vector_t>(
217 vgpr_data_idx_seq,
218 is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
219
220 // For each dimension move fwd, bwd or don't move
221 static_for<0, nDim, 1>{}([&](auto i) {
222 if constexpr(move_on_dim[i])
223 {
224 if constexpr(forward_sweep[i])
225 {
227 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
228 }
229 else
230 {
232 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
233 }
234 }
235 });
236 });
237 }
238
239 template <typename BlockBufferType>
240 __device__ void RunWrite(const DstDesc& dst_desc, BlockBufferType& dst_buf)
241 {
242 using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
243 using dst_vector_t = typename dst_vector_type::type;
244
245 constexpr auto src_access_lengths = NumberOfIterations{};
246 constexpr auto src_dim_access_order = IterationOrder{};
247 constexpr auto ordered_src_access_lengths =
248 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
249 constexpr auto ordered_fwd_step = StepsPerIteration{};
250
251 // make forward steps
252 // forward step for each iteration just add 1
253 const auto dst_forward_steps = generate_tuple(
254 [&](auto i) {
255 Index forward_step_idx;
256
257 static_for<0, nDim, 1>{}([&](auto j) {
258 forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0;
259 });
260
261 return make_tensor_coordinate_step(dst_desc, forward_step_idx);
262 },
263 Number<nDim>{});
264
265 // make backward steps
266 // backward step at the end of the dimension iteration subtract IterationLength - 1
267 const auto dst_backward_steps = generate_tuple(
268 [&](auto i) {
269 Index backward_step_idx;
270
271 static_for<0, nDim, 1>{}([&](auto j) {
272 backward_step_idx(j) = (i.value == j.value)
273 ? (-src_access_lengths[i] + 1) * ordered_fwd_step[i]
274 : 0;
275 });
276
277 return make_tensor_coordinate_step(dst_desc, backward_step_idx);
278 },
279 Number<nDim>{});
280
281 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
282 // judge move forward or move backward
283 constexpr auto forward_sweep = [&]() {
285
286 // Take condition for bwd and negate
287 // condition for bwd: dimension index is the last of iteration and
288 // all dimension indices of higher dimensions (inner loops)
289 // are the last of their iteration
290 static_for<0, nDim, 1>{}([&](auto i) {
291 bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1;
292 static_for<i + 1, nDim, 1>{}([&](auto j) {
293 tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
294 });
295 forward_sweep_(i) = !tmp;
296 });
297 return forward_sweep_;
298 }();
299
300 // check for each dimension, if it needs to be moved (either fwd or bwd)
301 constexpr auto move_on_dim = [&]() constexpr {
303
304 // forward condition
305 static_for<0, nDim, 1>{}([&](auto i) {
306 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
307
308 static_for<i + 1, nDim, 1>{}([&](auto j) {
309 move_on_dim_(i) &=
310 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
311 });
312 });
313
314 // backward condition
315 static_for<0, nDim, 1>{}([&](auto i) {
316 bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1 &&
317 ordered_src_access_idx[i] > 0;
318 static_for<i + 1, nDim, 1>{}([&](auto j) {
319 tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
320 });
321 move_on_dim_(i) |= tmp;
322 });
323
324 return move_on_dim_;
325 }();
326
327 // calculate src data index and make sequence
328 constexpr auto src_data_idx = [&]() {
329 Index ordered_idx;
330
332 [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
333
334 return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
335 }();
336
337 // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
338 constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
339 [&](auto i) {
340 if constexpr(i.value < src_data_idx.Size())
341 {
343 }
344 else
345 {
346 return Number<0>{};
347 }
348 },
349 Number<src_data_idx.Size() + 1>{});
350
351 // store element from vgpr to dst buffer
352 dst_buf.template Set<dst_vector_t>(
353 dst_coord_.GetOffset(),
354 true,
355 dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
356
357 // For each dimension move fwd, bwd or don't move
358 static_for<0, nDim, 1>{}([&](auto i) {
359 if constexpr(move_on_dim[i])
360 {
361 if constexpr(forward_sweep[i])
362 {
364 dst_desc, dst_coord_, dst_forward_steps[src_dim_access_order[i]]);
365 }
366 else
367 {
369 dst_desc, dst_coord_, dst_backward_steps[src_dim_access_order[i]]);
370 }
371 }
372 });
373 });
374 }
375
376 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
377 {
378 const auto adjusted_step = make_tensor_coordinate_step(src_desc, step);
379 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
380 }
381
382 private:
383 // descriptor of vgpr data
384 __device__ static constexpr auto GetThreadScratchDataDescriptor()
385 {
386 constexpr auto access_lengths_as_tuple = container_push_back(
387 sequence_to_tuple_of_number(NumberOfIterations{}), Number<VectorSize>{});
388
389 return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
390 }
391
392 static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
393 using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
394 DstData,
395 VectorSize,
396 decltype(thread_data_scratch_desc_),
397 true>;
398
399 ThreadScratchData dvgpr_;
400 SrcCoord src_coord_;
401 DstCoord dst_coord_;
402 const ElementwiseOperation element_op_;
403};
404
405} // namespace ck
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
@ Set
Definition ck.hpp:278
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition thread_group_tensor_slice_transfer_global.hpp:38
__device__ ThreadGroupTransferGlobal(const SrcDesc &src_desc, const DstDesc &dst_desc, const Index &src_block_slice_origin, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_global.hpp:40
static constexpr auto I6
Definition thread_group_tensor_slice_transfer_global.hpp:33
static constexpr auto I1
Definition thread_group_tensor_slice_transfer_global.hpp:28
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_global.hpp:35
__device__ void RunRead(const SrcDesc &src_desc, const GridBufferType &grid_buf)
Definition thread_group_tensor_slice_transfer_global.hpp:52
static constexpr auto I2
Definition thread_group_tensor_slice_transfer_global.hpp:29
static constexpr auto I0
Definition thread_group_tensor_slice_transfer_global.hpp:27
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_global.hpp:376
static constexpr auto I4
Definition thread_group_tensor_slice_transfer_global.hpp:31
static constexpr auto I3
Definition thread_group_tensor_slice_transfer_global.hpp:30
static constexpr auto I5
Definition thread_group_tensor_slice_transfer_global.hpp:32
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_global.hpp:36
__device__ void RunWrite(const DstDesc &dst_desc, BlockBufferType &dst_buf)
Definition thread_group_tensor_slice_transfer_global.hpp:240
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition thread_group_tensor_slice_transfer_global.hpp:37
Definition functional2.hpp:33
Definition functional3.hpp:97
Definition dtype_vector.hpp:30
Definition dtype_vector.hpp:10