device_reduce_threadwise_multi_d.hpp Source File

device_reduce_threadwise_multi_d.hpp Source File

Composable Kernel: device_reduce_threadwise_multi_d.hpp Source File
device_reduce_threadwise_multi_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <array>
9
14
16
17namespace ck {
18namespace tensor_operation {
19namespace device {
20
21template <typename InDataType,
22 typename DsDataType,
23 typename AccDataType,
24 typename OutDataType,
25 index_t Rank,
26 index_t NumReduceDim,
27 typename ReduceOperation,
28 typename InElementwiseOperation,
29 typename OutElementwiseOperation,
30 index_t BlockSize,
31 index_t MThreadSliceSize,
32 index_t KThreadSliceSize,
33 index_t InSrcVectorDim,
34 index_t InSrcVectorSize,
35 index_t OutDstVectorSize,
36 typename DsVectorSizeSequence>
38 DsDataType,
39 AccDataType,
40 OutDataType,
41 Rank,
42 NumReduceDim,
43 ReduceOperation,
44 InElementwiseOperation,
45 OutElementwiseOperation>
46
47{
// Compile-time guards and derived constants for this reduction instance.
// Tensors of rank greater than 12 are rejected at compile time.
48 static_assert(Rank <= 12, "Bigger Rank size is not supported!");
49
// The input vector width must divide the thread slice along the vectorized dimension
// (MThreadSliceSize when InSrcVectorDim == 0, KThreadSliceSize when InSrcVectorDim == 1),
// and the output vector width must divide MThreadSliceSize.
50 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
51 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
52 (MThreadSliceSize % OutDstVectorSize == 0),
53 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
54
56
// Number of dimensions that remain after the reduced ones are removed.
57 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
58
// Number of D tensors, as reported by DsDataType::Size().
59 static constexpr index_t NumDTensor = DsDataType::Size();
60
61 static constexpr index_t NumSrcDim = Rank;
// The destination always has at least one dimension: when every input dimension is
// reduced (NumInvariantDim == 0) NumDstDim is 1 and reduceAllDim is true.
62 static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
63 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
64
// Tile extents: M is BlockSize * MThreadSliceSize, K is KThreadSliceSize (factor 1).
// The descriptor builders in this class pad lengths up to multiples of these values.
65 static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
66 static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
67
// Builds the padded 2-D (M, K) descriptor of the input tensor from its lengths and strides.
// Dimension 0 of the result is read back as the invariant length and dimension 1 as the
// reduce length; both are right-padded to multiples of M_BlockTileSize / K_BlockTileSize.
// NOTE(review): several lines of this function are absent from this listing (the
// dimension-id arguments of the transform calls and one `return`); the comments below
// describe only what is visible.
68 static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
69 const std::array<index_t, Rank>& inStrides)
70 {
// Convert the std::array lengths/strides into tuples of Rank elements.
71 const auto tupleSrcLengths =
72 generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
73 const auto tupleSrcStrides =
74 generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
75
76 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
77
// Collapse the Rank-dimensional descriptor into two dimensions; the branch is chosen
// at compile time.
78 const auto in_grid_desc_m_k = [&]() {
79 if constexpr(reduceAllDim)
80 {
// All dimensions are reduced: merge every source dimension into a single one first.
81 const auto one_dim_inDesc = transform_tensor_descriptor(
82 inDesc,
83 make_tuple(make_merge_transform(tupleSrcLengths)),
86
// The visible argument fragment pairs a length of 1 with the merged length —
// presumably an unmerge into (1, total); the transform name is not shown here, confirm.
87 return transform_tensor_descriptor(one_dim_inDesc,
89 1, one_dim_inDesc.GetLength(Number<0>{})))),
92 }
93 else
94 {
// Invariant dimensions are indices [0, NumInvariantDim); reduce lengths are read from
// inLengths starting at index NumInvariantDim. ReduceDims is defined on a line that
// is not shown in this listing.
95 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
97
98 const auto reduceDimLengths = generate_tuple(
99 [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
100 const auto invariantDimLengths =
101 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
102
// Merge the invariant group and the reduce group into one dimension each.
104 inDesc,
105 make_tuple(make_merge_transform(invariantDimLengths),
106 make_merge_transform(reduceDimLengths)),
107 make_tuple(InvariantDims{}, ReduceDims{}),
109 }
110 }();
111
112 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
113 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
114
// Amount of right padding needed to reach the next multiple of the tile size.
115 const auto inPad_M =
116 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
117 const auto inPad_K =
118 math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
119
120 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
121 in_grid_desc_m_k,
122 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
123 make_right_pad_transform(reduceLength, inPad_K)),
126
127 return (in_grid_desc_m_k_padded);
128 };
129
// Builds the padded 1-D (M) descriptor of the output tensor: all NumDstDim dimensions
// are merged into one, which is then right-padded to a multiple of M_BlockTileSize.
// NOTE(review): the dimension-id arguments of both transform calls are absent from
// this listing.
130 static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
131 const std::array<index_t, NumDstDim>& outStrides)
132 {
// Convert the std::array lengths/strides into tuples of NumDstDim elements.
133 const auto tupleDstLengths =
134 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
135 const auto tupleDstStrides =
136 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
137
138 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
139
// Merge every destination dimension into a single one.
140 auto out_grid_desc_m = transform_tensor_descriptor(
141 outDesc,
142 make_tuple(make_merge_transform(tupleDstLengths)),
145
146 const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
147
// Amount of right padding needed to reach the next multiple of M_BlockTileSize.
148 const auto outPad =
149 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
150
151 auto out_grid_desc_m_padded = transform_tensor_descriptor(
152 out_grid_desc_m,
153 make_tuple(make_right_pad_transform(invariantLength, outPad)),
156 return (out_grid_desc_m_padded);
157 };
158
// Produces a tuple of descriptors via generate_tuple, one entry per index i of the
// D tensors, from the per-tensor lengths and strides.
// NOTE(review): the first line of the lambda body and the trailing generate_tuple
// argument are absent from this listing; the visible `DsStrides[i]);` suggests each entry
// comes from a call taking DsLengths[i] and DsStrides[i] (presumably MakeDst1dDescriptor)
// — confirm against the real header.
// Both arrays are taken by value; only DsLengths is const-qualified.
159 static auto
160 MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
161 std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
162 {
163 return generate_tuple(
164 [&](auto i) {
166 DsStrides[i]);
167 },
169 }
170
// Descriptor types, deduced from the return types of the builder functions above.
// The `{}` arguments only drive type deduction inside decltype; nothing is evaluated.
171 using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
172 using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
173 using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
174
177 DsDataType,
178 OutDataType,
179 AccDataType,
183 ReduceOperation,
184 InElementwiseOperation,
185 OutElementwiseOperation,
187 BlockSize,
188 MThreadSliceSize,
189 KThreadSliceSize,
190 InSrcVectorDim,
191 InSrcVectorSize,
192 OutDstVectorSize,
193 DsVectorSizeSequence>;
194
196
// Bundle of everything one launch of this reduction needs: tensor shapes, device
// pointers, elementwise operators and the launch grid size.
// NOTE(review): this listing omits most of the constructor body (listing lines 220-237
// are only partly shown) and several member declarations. According to the symbol index
// at the end of this page the omitted members are p_ds_grid_, ds_grid_desc_m_,
// invariant_lowest_length, reduce_lowest_length, invariant_total_length,
// reduce_total_length and numBlockTileIteration.
197 struct Argument : public BaseArgument
198 {
// All array parameters are taken by value. in_dev / out_dev are already typed, while the
// D tensor pointers arrive as `const void*` and are typed inside the body.
199 Argument(const std::array<index_t, Rank> inLengths,
200 const std::array<index_t, Rank> inStrides,
201 const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
202 const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
203 const std::array<index_t, NumDstDim> outLengths,
204 const std::array<index_t, NumDstDim> outStrides,
205 const std::array<int, NumReduceDim> reduceDims,
206 const InDataType* in_dev,
207 const std::array<const void*, NumDTensor> ds_dev,
208 OutDataType* out_dev,
209 const InElementwiseOperation in_elementwise_op,
210 const OutElementwiseOperation out_elementwise_op)
// inLengths_ / inStrides_ are not in this initializer list; presumably they are assigned
// in the omitted body lines (where reduceDims would also be consumed) — confirm.
211 : DsLengths_{DsLengths},
212 DsStrides_{DsStrides},
213 outLengths_{outLengths},
214 outStrides_{outStrides},
215 in_dev_{in_dev},
216 out_dev_{out_dev},
217 in_elementwise_op_{in_elementwise_op},
218 out_elementwise_op_{out_elementwise_op}
219 {
222
225
// Only the branch skeleton survives here; the statements of both branches are omitted.
226 if constexpr(NumInvariantDim == 0)
228 else
230
232
234
237
// Give each untyped D pointer the element type found at the same index in DsDataType.
238 static_for<0, NumDTensor, 1>{}([&](auto i) {
239 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
240 p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
241 });
242
243 ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
244 }
245
// Input tensor shape.
246 std::array<index_t, Rank> inLengths_;
247 std::array<index_t, Rank> inStrides_;
248
// Shapes of the D tensors, one lengths/strides pair per tensor.
249 std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
250 std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
251
// Output tensor shape.
252 std::array<index_t, NumDstDim> outLengths_;
253 std::array<index_t, NumDstDim> outStrides_;
254
// Input and output pointers; Invoker::Run forwards them to the kernel launch.
255 const InDataType* in_dev_;
256 OutDataType* out_dev_;
257
259
260 InElementwiseOperation in_elementwise_op_;
261 OutElementwiseOperation out_elementwise_op_;
262
264
269
// Passed as the grid dimension by Invoker::Run; its computation is not shown in this listing.
271 size_t gridSize;
272 };
273
// Launches the reduction kernel for a prepared Argument and returns the time reported
// by launch_and_time_kernel.
// NOTE(review): this listing omits the lines that build the two descriptors, select the
// kernel instantiation, and pass two of the kernel arguments.
274 struct Invoker : public BaseInvoker
275 {
// Typed entry point: launches with arg.gridSize blocks of BlockSize threads.
276 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
277 {
// The initializers of both descriptors are omitted; presumably they come from
// MakeSrc2dDescriptor / MakeDst1dDescriptor applied to the shapes in `arg` — confirm.
278 const auto in_grid_desc_m_k =
280 const auto out_grid_desc_m =
282
283 float avg_time = 0;
284
// Remnant of the template-argument list of the `kernel` definition.
286 InDataType,
287 OutDataType,
288 AccDataType,
292 InElementwiseOperation,
293 OutElementwiseOperation,
295
// The literal 0 sits in the lds_byte position of launch_and_time_kernel. Two arguments
// between out_grid_desc_m and arg.in_dev_ are omitted; the kernel signature in the
// symbol index places the input and output elementwise operators there.
296 avg_time = launch_and_time_kernel(stream_config,
297 kernel,
298 dim3(arg.gridSize),
299 dim3(BlockSize),
300 0,
301 in_grid_desc_m_k,
302 arg.ds_grid_desc_m_,
303 out_grid_desc_m,
306 arg.in_dev_,
307 arg.p_ds_grid_,
308 arg.out_dev_);
309
310 return (avg_time);
311 };
312
// Type-erased entry point required by BaseInvoker.
// NOTE(review): the dynamic_cast result is dereferenced unchecked; a null p_arg or an
// argument that is not this class's Argument yields a null pointer dereference.
313 float Run(const BaseArgument* p_arg,
314 const StreamConfig& stream_config = StreamConfig{}) override
315 {
316 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
317 };
318 };
319
320 bool IsSupportedArgument(const BaseArgument* p_arg) override
321 {
322 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
323
324 if constexpr(InSrcVectorDim == 0)
325 {
326 if constexpr(NumInvariantDim == 0)
327 {
328 return (false);
329 }
330 else
331 {
332 if(pArg->inStrides_[NumInvariantDim - 1] != 1)
333 return (false);
334
335 if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
336 return (false);
337 };
338 }
339 else
340 {
341 if(pArg->inStrides_[Rank - 1] != 1)
342 return (false);
343
344 if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
345 return (false);
346 };
347
348 // To improve
349 if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
350 return (false);
351
352 std::cerr << "reduce_total_length = " << pArg->reduce_total_length
353 << " KThreadSliceSize = " << KThreadSliceSize << std::endl;
354
355 // cases with big reduce_total_length should be handled by Blockwise kernel
356 if(pArg->reduce_total_length / KThreadSliceSize >= 32)
357 return (false);
358
359 return (true);
360 };
361
362 std::unique_ptr<BaseArgument>
363 MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
364 const std::array<index_t, Rank> inStrides,
365 const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
366 const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
367 const std::array<index_t, NumDstDim> outLengths,
368 const std::array<index_t, NumDstDim> outStrides,
369 const std::array<int, NumReduceDim> reduceDims,
370 const void* in_dev,
371 const std::array<const void*, NumDTensor> ds_dev,
372 void* out_dev,
373 const InElementwiseOperation in_elementwise_op,
374 const OutElementwiseOperation out_elementwise_op) override
375 {
376 return std::make_unique<Argument>(inLengths,
377 inStrides,
378 DsLengths,
379 DsStrides,
380 outLengths,
381 outStrides,
382 reduceDims,
383 static_cast<const InDataType*>(in_dev),
384 ds_dev,
385 static_cast<OutDataType*>(out_dev),
386 in_elementwise_op,
387 out_elementwise_op);
388 };
389
390 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
391 {
392 return std::make_unique<Invoker>();
393 };
394
395 std::string GetTypeString() const override
396 {
397 auto str = std::stringstream();
398
399 // clang-format off
400 str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ",";
401 str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
402 str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
403 str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
404 // clang-format on
405
406 return str.str();
407 }
408};
409
410} // namespace device
411} // namespace tensor_operation
412} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, const DsGridDesc_M ds_grid_desc_m, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_value_global, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:28
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:66
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:98
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_reduce_multi_d.hpp:26
Definition device_reduce_threadwise_multi_d.hpp:198
long_index_t invariant_total_length
Definition device_reduce_threadwise_multi_d.hpp:267
OutElementwiseOperation out_elementwise_op_
Definition device_reduce_threadwise_multi_d.hpp:261
size_t gridSize
Definition device_reduce_threadwise_multi_d.hpp:271
std::array< index_t, NumDstDim > outLengths_
Definition device_reduce_threadwise_multi_d.hpp:252
const InDataType * in_dev_
Definition device_reduce_threadwise_multi_d.hpp:255
std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides_
Definition device_reduce_threadwise_multi_d.hpp:250
DsGridPointer p_ds_grid_
Definition device_reduce_threadwise_multi_d.hpp:258
index_t invariant_lowest_length
Definition device_reduce_threadwise_multi_d.hpp:265
std::array< index_t, Rank > inStrides_
Definition device_reduce_threadwise_multi_d.hpp:247
index_t reduce_lowest_length
Definition device_reduce_threadwise_multi_d.hpp:266
int numBlockTileIteration
Definition device_reduce_threadwise_multi_d.hpp:270
std::array< index_t, Rank > inLengths_
Definition device_reduce_threadwise_multi_d.hpp:246
InElementwiseOperation in_elementwise_op_
Definition device_reduce_threadwise_multi_d.hpp:260
OutDataType * out_dev_
Definition device_reduce_threadwise_multi_d.hpp:256
std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths_
Definition device_reduce_threadwise_multi_d.hpp:249
long_index_t reduce_total_length
Definition device_reduce_threadwise_multi_d.hpp:268
std::array< index_t, NumDstDim > outStrides_
Definition device_reduce_threadwise_multi_d.hpp:253
DsGridDesc_M ds_grid_desc_m_
Definition device_reduce_threadwise_multi_d.hpp:263
Argument(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const InDataType *in_dev, const std::array< const void *, NumDTensor > ds_dev, OutDataType *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op)
Definition device_reduce_threadwise_multi_d.hpp:199
Definition device_reduce_threadwise_multi_d.hpp:275
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_reduce_threadwise_multi_d.hpp:276
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_reduce_threadwise_multi_d.hpp:313
Definition device_reduce_threadwise_multi_d.hpp:47
static auto MakeSrc2dDescriptor(const std::array< index_t, Rank > &inLengths, const std::array< index_t, Rank > &inStrides)
Definition device_reduce_threadwise_multi_d.hpp:68
std::string GetTypeString() const override
Definition device_reduce_threadwise_multi_d.hpp:395
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition device_reduce_threadwise_multi_d.hpp:363
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_reduce_threadwise_multi_d.hpp:320
static auto MakeDst1dDescriptor(const std::array< index_t, NumDstDim > &outLengths, const std::array< index_t, NumDstDim > &outStrides)
Definition device_reduce_threadwise_multi_d.hpp:130
GridwiseReduction_mk_to_m_threadwise_multi_d< ReduceDataType, DsDataType, CDataType, GemmAccDataType, InGridDesc_M_K, DsGridDesc_M, OutGridDesc_M, ReduceAdd, PassThrough, OutElementwiseOperation, InMemoryDataOperationEnum::Set, BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, decltype(DsVectorLengthSequence) > GridwiseReduce
Definition device_reduce_threadwise_multi_d.hpp:175
static auto MakeDsDescriptor(const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides)
Definition device_reduce_threadwise_multi_d.hpp:160
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_reduce_threadwise_multi_d.hpp:390