device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File

device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File

Composable Kernel: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File
device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
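24// Forward convolution: out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]. This backward-data operator writes in[N, Hi, Wi, C] from out[N, Ho, Wo, K] and wei[K, Y, X, C].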
25template <ck::index_t NDimSpatial,
26 typename InDataType,
27 typename WeiDataType,
28 typename OutDataType,
29 typename AccDataType,
30 typename InElementwiseOperation,
31 typename WeiElementwiseOperation,
32 typename OutElementwiseOperation,
33 ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 index_t M1PerThread,
40 index_t N1PerThread,
41 index_t KPerThread,
42 typename M1N1ThreadClusterM1Xs,
43 typename M1N1ThreadClusterN1Xs,
44 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
45 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
46 typename ABlockTransferThreadClusterArrangeOrder,
47 typename ABlockTransferSrcAccessOrder,
48 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
49 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
50 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
51 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
52 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
53 typename BBlockTransferThreadClusterArrangeOrder,
54 typename BBlockTransferSrcAccessOrder,
55 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
56 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
57 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
58 typename CThreadTransferSrcDstAccessOrder,
59 index_t CThreadTransferSrcDstVectorDim,
60 index_t CThreadTransferDstScalarPerVector>
62 : public DeviceConvBwdData<
63 NDimSpatial,
64 ck::tuple_element_t<NDimSpatial - 1,
65 ck::Tuple<ck::tensor_layout::convolution::NWC,
66 ck::tensor_layout::convolution::NHWC,
67 ck::tensor_layout::convolution::NDHWC>>,
68 ck::tuple_element_t<NDimSpatial - 1,
69 ck::Tuple<ck::tensor_layout::convolution::KXC,
70 ck::tensor_layout::convolution::KYXC,
71 ck::tensor_layout::convolution::KZYXC>>,
72 ck::tuple_element_t<NDimSpatial - 1,
73 ck::Tuple<ck::tensor_layout::convolution::NWK,
74 ck::tensor_layout::convolution::NHWK,
75 ck::tensor_layout::convolution::NDHWK>>,
76 InDataType,
77 WeiDataType,
78 OutDataType,
79 InElementwiseOperation,
80 WeiElementwiseOperation,
81 OutElementwiseOperation>
82{
84
// GEMM operand roles used throughout this operator: A is taken from the output
// tensor, B from the weight tensor, and C is the input tensor.
85 using ADataType = OutDataType;
86 using BDataType = WeiDataType;
87 using CDataType = InDataType;
88
89 // TODO make A/B datatype different
// NOTE(review): presumably the single element type handed to the GEMM for both A and B
// until the TODO above is resolved; its use is not visible in this listing.
90 using ABDataType = InDataType;
91
// Compile-time integral constants; used below both to index tuples (e.g. descs[I0])
// and as literal 0/1 operands in descriptor arithmetic.
92 static constexpr auto I0 = Number<0>{};
93 static constexpr auto I1 = Number<1>{};
94 static constexpr auto I2 = Number<2>{};
95 static constexpr auto I3 = Number<3>{};
96 static constexpr auto I4 = Number<4>{};
97 static constexpr auto I5 = Number<5>{};
98 static constexpr auto I6 = Number<6>{};
99 static constexpr auto I7 = Number<7>{};
100
// Builds the GEMM-view tensor descriptors for the 1-D case (one spatial dimension, W).
// Returns make_tuple(A, B, C): A is derived from the output tensor, B from the weight
// tensor and C from the input tensor.
// tildes[0] is i_xtilde, the index that selects which of the per-tilde GEMMs is described.
// NOTE(review): this listing has dropped several original lines (the function name, the
// leading parameter(s) such as N, and parts of most transform_tensor_descriptor calls);
// the comments below only describe what is still visible.
101 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
102 static auto
104 ck::index_t K,
105 ck::index_t C,
106 std::vector<ck::index_t> input_spatial_lengths,
107 std::vector<ck::index_t> filter_spatial_lengths,
108 std::vector<ck::index_t> output_spatial_lengths,
109 std::vector<ck::index_t> conv_filter_strides,
110 std::vector<ck::index_t> conv_filter_dilations,
111 std::vector<ck::index_t> input_left_pads,
112 std::vector<ck::index_t> input_right_pads,
113 std::vector<ck::index_t> tildes)
114 {
115 using namespace ck;
116
117 index_t i_xtilde = tildes[0];
118
    // Unpack the single spatial dimension from each per-dimension vector.
119 const index_t Wi = input_spatial_lengths[0];
120 const index_t Wo = output_spatial_lengths[0];
121 const index_t X = filter_spatial_lengths[0];
122 const index_t InLeftPadW = input_left_pads[0];
123 const index_t InRightPadW = input_right_pads[0];
124 const index_t ConvStrideW = conv_filter_strides[0];
125 const index_t ConvDilationW = conv_filter_dilations[0];
126
    // Integer division. NOTE(review): presumably K must be a multiple of K1 - confirm
    // where arguments are validated.
127 const auto K0 = K / K1;
128
129 const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
130
    // Compile-time branch on the specialization; the enumerator being compared against
    // is on a line missing from this listing.
131 if constexpr(ConvBackwardDataSpecialization ==
133 {
134 // A: output tensor
135 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
141
142 // B: weight tensor
143 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
149
150 // C: input tensor
151 const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
152 in_n_wi_c_grid_desc,
154 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
158
159 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
160 in_n_x_wo_c_grid_desc,
166
167 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
168 wei_gemmk0_gemmn_gemmk1_grid_desc,
169 in_gemmm_gemmn_grid_desc);
170 }
171 else
172 {
173 const auto out_n_wo_k_grid_desc =
175 const auto wei_k_x_c_grid_desc =
177
178 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
179
    // XTilde = stride / gcd(stride, dilation); XDot = ceil(X / XTilde).
180 const auto XTilde = ConvStrideW / GcdStrideDilationW;
181
182 const auto XDot = math::integer_divide_ceil(X, XTilde);
183
184 const auto WTilde =
185 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
186
187 // only work on the part of WTilde that contributes to the non-padding area of the input tensor
188 const auto IWTildeSliceBegin = math::integer_divide_floor(
189 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
190
191 const auto IWTildeSliceEnd = math::min(
192 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
193
194 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
195
196 // GemmK is different for each GEMM
    // XDotSlice depends on i_xtilde, which is why the value differs per GEMM.
197 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
198
199 // A: output tensor
200 const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
201 out_n_wo_k_grid_desc,
207
    // Embed (XDot, WTilde) onto the W axis with coefficients (-dilation / gcd, 1).
208 const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
209 out_n_wop_k_grid_desc,
212 make_embed_transform(make_tuple(XDot, WTilde),
213 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
217
    // Keep the first XDotSlice entries of XDot and the [begin, begin + WTildeSlice) window of WTilde.
218 const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
219 out_n_xdot_wtilde_k_grid_desc,
221 make_slice_transform(XDot, I0, XDotSlice),
222 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
226
227 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
228 out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
230 make_merge_transform(make_tuple(N, WTildeSlice)),
234
235 // B: weight tensor
236 const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
237 wei_k_x_c_grid_desc,
239 make_embed_transform(make_tuple(XDot, XTilde),
240 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
244
    // The XTilde axis is frozen at i_xtilde; XDot is cut to XDotSlice entries.
245 const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
246 wei_k_xdot_xtilde_c_grid_desc,
248 make_slice_transform(XDot, I0, XDotSlice),
249 make_freeze_transform(i_xtilde),
253
254 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
255 wei_k0_k1_xdotslice_c_grid_desc,
261
262 // C: input tensor
263 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
264 in_n_wi_c_grid_desc,
266 make_pad_transform(Wi, InLeftPadW, InRightPadW),
270
    // Embed (XTilde, WTilde) onto the padded W axis with coefficients (dilation, stride).
271 const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
272 in_n_wip_c_grid_desc,
274 make_embed_transform(make_tuple(XTilde, WTilde),
275 make_tuple(ConvDilationW, ConvStrideW)),
279
280 const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
281 in_n_xtilde_wtilde_c_grid_desc,
283 make_freeze_transform(i_xtilde),
284 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
288
289 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
290 in_n_wtildeslice_c_grid_desc,
295
296 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
297 wei_gemmk0_gemmn_gemmk1_grid_desc,
298 in_gemmm_gemmn_grid_desc);
299 }
300
301 } // function end
// Builds the GEMM-view tensor descriptors for the 2-D case (spatial dimensions H and W).
// Returns make_tuple(A, B, C): A is derived from the output tensor, B from the weight
// tensor and C from the input tensor.
// tildes = {i_ytilde, i_xtilde} selects which of the per-tilde GEMMs is described.
// All per-dimension vectors are read as [0] = H, [1] = W.
// NOTE(review): this listing has dropped several original lines (the function name, the
// leading parameter(s) such as N, and parts of most transform_tensor_descriptor calls);
// the comments below only describe what is still visible.
302 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
303 static auto
305 ck::index_t K,
306 ck::index_t C,
307 std::vector<ck::index_t> input_spatial_lengths,
308 std::vector<ck::index_t> filter_spatial_lengths,
309 std::vector<ck::index_t> output_spatial_lengths,
310 std::vector<ck::index_t> conv_filter_strides,
311 std::vector<ck::index_t> conv_filter_dilations,
312 std::vector<ck::index_t> input_left_pads,
313 std::vector<ck::index_t> input_right_pads,
314 std::vector<ck::index_t> tildes)
315 {
316 using namespace ck;
317
318 index_t i_ytilde = tildes[0];
319 index_t i_xtilde = tildes[1];
320
321 const index_t Hi = input_spatial_lengths[0];
322 const index_t Wi = input_spatial_lengths[1];
323
324 const index_t Ho = output_spatial_lengths[0];
325 const index_t Wo = output_spatial_lengths[1];
326
327 const index_t Y = filter_spatial_lengths[0];
328 const index_t X = filter_spatial_lengths[1];
329
330 const index_t InLeftPadH = input_left_pads[0];
331 const index_t InLeftPadW = input_left_pads[1];
332
333 const index_t InRightPadH = input_right_pads[0];
334 const index_t InRightPadW = input_right_pads[1];
335
336 const index_t ConvStrideH = conv_filter_strides[0];
337 const index_t ConvStrideW = conv_filter_strides[1];
338
339 const index_t ConvDilationH = conv_filter_dilations[0];
340 const index_t ConvDilationW = conv_filter_dilations[1];
341
    // Integer division. NOTE(review): presumably K must be a multiple of K1 - confirm
    // where arguments are validated.
342 const auto K0 = K / K1;
343
344 const auto out_n_ho_wo_k_grid_desc =
346 const auto wei_k_y_x_c_grid_desc =
348 const auto in_n_hi_wi_c_grid_desc =
350
    // Compile-time branch on the specialization; the enumerator being compared against
    // is on a line missing from this listing.
351 if constexpr(ConvBackwardDataSpecialization ==
353 {
354 // A: output tensor
355 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
361
362 // B: weight tensor
363 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
369
370 // C: input tensor
371 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
372 in_n_hi_wi_c_grid_desc,
374 make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
375 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
379
380 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
381 in_n_y_ho_x_wo_c_grid_desc,
388
389 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
390 wei_gemmk0_gemmn_gemmk1_grid_desc,
391 in_gemmm_gemmn_grid_desc);
392 }
393 else
394 {
395 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
396 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
397
    // Per dimension: Tilde = stride / gcd(stride, dilation); Dot = ceil(filter_length / Tilde).
398 const auto YTilde = ConvStrideH / GcdStrideDilationH;
399 const auto XTilde = ConvStrideW / GcdStrideDilationW;
400
401 const auto YDot = math::integer_divide_ceil(Y, YTilde);
402 const auto XDot = math::integer_divide_ceil(X, XTilde);
403
404 const auto HTilde =
405 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
406 const auto WTilde =
407 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
408
409 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
410 const auto IHTildeSliceBegin = math::integer_divide_floor(
411 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
412 const auto IWTildeSliceBegin = math::integer_divide_floor(
413 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
414
415 const auto IHTildeSliceEnd = math::min(
416 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
417 const auto IWTildeSliceEnd = math::min(
418 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
419
420 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
421 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
422
423 // GemmK is different for each GEMM
    // The Dot slices depend on (i_ytilde, i_xtilde), which is why the value differs per GEMM.
424 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
425 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
426
427 // A: output tensor
428 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
429 out_n_ho_wo_k_grid_desc,
436
    // Embed (Dot, Tilde) onto each spatial axis with coefficients (-dilation / gcd, 1).
437 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
438 out_n_hop_wop_k_grid_desc,
441 make_embed_transform(make_tuple(YDot, HTilde),
442 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
443 make_embed_transform(make_tuple(XDot, WTilde),
444 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
448
    // Slice every Dot axis to its DotSlice length and every Tilde axis to its non-padding window.
449 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
451 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
453 make_slice_transform(YDot, I0, YDotSlice),
454 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
455 make_slice_transform(XDot, I0, XDotSlice),
456 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
459 Sequence<1>{},
460 Sequence<2>{},
461 Sequence<3>{},
462 Sequence<4>{},
463 Sequence<5>{}),
465 Sequence<1>{},
466 Sequence<2>{},
467 Sequence<3>{},
468 Sequence<4>{},
469 Sequence<5, 6>{}));
470
    // Merge (YDotSlice, XDotSlice, K0) into one dimension and (N, HTildeSlice, WTildeSlice) into another.
471 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
472 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
473 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
474 make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
478
479 // B: weight tensor
480 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
481 wei_k_y_x_c_grid_desc,
483 make_embed_transform(make_tuple(YDot, YTilde),
484 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
485 make_embed_transform(make_tuple(XDot, XTilde),
486 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
490
    // The Tilde axes are frozen at (i_ytilde, i_xtilde); they map to empty Sequence<>{}
    // entries in the upper-dimension list, i.e. they do not appear in the result.
491 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
492 transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
494 make_slice_transform(YDot, I0, YDotSlice),
495 make_slice_transform(XDot, I0, XDotSlice),
496 make_freeze_transform(i_ytilde),
497 make_freeze_transform(i_xtilde),
500 Sequence<1>{},
501 Sequence<3>{},
502 Sequence<2>{},
503 Sequence<4>{},
504 Sequence<5>{}),
506 Sequence<2>{},
507 Sequence<3>{},
508 Sequence<>{},
509 Sequence<>{},
510 Sequence<4>{}));
511
512 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
513 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
514 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
519
520 // C: input tensor
521 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
522 in_n_hi_wi_c_grid_desc,
524 make_pad_transform(Hi, InLeftPadH, InRightPadH),
525 make_pad_transform(Wi, InLeftPadW, InRightPadW),
529
    // Embed (Tilde_filter, Tilde_spatial) onto each padded axis with coefficients (dilation, stride).
530 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
531 in_n_hip_wip_c_grid_desc,
533 make_embed_transform(make_tuple(YTilde, HTilde),
534 make_tuple(ConvDilationH, ConvStrideH)),
535 make_embed_transform(make_tuple(XTilde, WTilde),
536 make_tuple(ConvDilationW, ConvStrideW)),
540
541 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
542 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
544 make_freeze_transform(i_ytilde),
545 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
546 make_freeze_transform(i_xtilde),
547 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
550 Sequence<1>{},
551 Sequence<2>{},
552 Sequence<3>{},
553 Sequence<4>{},
554 Sequence<5>{}),
556 Sequence<>{},
557 Sequence<1>{},
558 Sequence<>{},
559 Sequence<2>{},
560 Sequence<3>{}));
561
562 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
563 in_n_htildeslice_wtildeslice_c_grid_desc,
564 make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
568
569 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
570 wei_gemmk0_gemmn_gemmk1_grid_desc,
571 in_gemmm_gemmn_grid_desc);
572 }
573
574 } // function end
575
// Builds the GEMM-view tensor descriptors for the 3-D case (spatial dimensions D, H and W).
// Returns make_tuple(A, B, C): A is derived from the output tensor, B from the weight
// tensor and C from the input tensor.
// tildes = {i_ztilde, i_ytilde, i_xtilde} selects which of the per-tilde GEMMs is described.
// All per-dimension vectors are read as [0] = D, [1] = H, [2] = W.
// NOTE(review): this listing has dropped several original lines (the function name, the
// leading parameter(s) such as N, and parts of most transform_tensor_descriptor calls);
// the comments below only describe what is still visible.
576 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
577 static auto
579 ck::index_t K,
580 ck::index_t C,
581 std::vector<ck::index_t> input_spatial_lengths,
582 std::vector<ck::index_t> filter_spatial_lengths,
583 std::vector<ck::index_t> output_spatial_lengths,
584 std::vector<ck::index_t> conv_filter_strides,
585 std::vector<ck::index_t> conv_filter_dilations,
586 std::vector<ck::index_t> input_left_pads,
587 std::vector<ck::index_t> input_right_pads,
588 std::vector<ck::index_t> tildes)
589 {
590 using namespace ck;
591
592 const index_t i_ztilde = tildes[0];
593 const index_t i_ytilde = tildes[1];
594 const index_t i_xtilde = tildes[2];
595
596 const index_t Di = input_spatial_lengths[0];
597 const index_t Hi = input_spatial_lengths[1];
598 const index_t Wi = input_spatial_lengths[2];
599
600 const index_t Do = output_spatial_lengths[0];
601 const index_t Ho = output_spatial_lengths[1];
602 const index_t Wo = output_spatial_lengths[2];
603
604 const index_t Z = filter_spatial_lengths[0];
605 const index_t Y = filter_spatial_lengths[1];
606 const index_t X = filter_spatial_lengths[2];
607
608 const index_t InLeftPadD = input_left_pads[0];
609 const index_t InLeftPadH = input_left_pads[1];
610 const index_t InLeftPadW = input_left_pads[2];
611
612 const index_t InRightPadD = input_right_pads[0];
613 const index_t InRightPadH = input_right_pads[1];
614 const index_t InRightPadW = input_right_pads[2];
615
616 const index_t ConvStrideD = conv_filter_strides[0];
617 const index_t ConvStrideH = conv_filter_strides[1];
618 const index_t ConvStrideW = conv_filter_strides[2];
619
620 const index_t ConvDilationD = conv_filter_dilations[0];
621 const index_t ConvDilationH = conv_filter_dilations[1];
622 const index_t ConvDilationW = conv_filter_dilations[2];
623
    // Integer division. NOTE(review): presumably K must be a multiple of K1 - confirm
    // where arguments are validated.
624 const auto K0 = K / K1;
625
626 const auto out_n_do_ho_wo_k_grid_desc =
628 const auto wei_k_z_y_x_c_grid_desc =
630 const auto in_n_di_hi_wi_c_grid_desc =
632
    // Compile-time branch on the specialization; the enumerator being compared against
    // is on a line missing from this listing.
633 if constexpr(ConvBackwardDataSpecialization ==
635 {
636 // A: output tensor
637 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
639 make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
643
644 // B: weight tensor
645 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
651
652 // C: input tensor
653 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
654 in_n_di_hi_wi_c_grid_desc,
656 make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
657 make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
658 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
666 Sequence<7>{}));
667
668 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
669 in_n_z_do_y_ho_x_wo_c_grid_desc,
673 make_merge_transform(make_tuple(N, Do, Ho, Wo)),
676 Sequence<3>{},
677 Sequence<5>{},
679 Sequence<7>{}),
681
682 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
683 wei_gemmk0_gemmn_gemmk1_grid_desc,
684 in_gemmm_gemmn_grid_desc);
685 }
686 else
687 {
688 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
689 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
690 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
691
    // Per dimension: Tilde = stride / gcd(stride, dilation); Dot = ceil(filter_length / Tilde).
692 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
693 const auto YTilde = ConvStrideH / GcdStrideDilationH;
694 const auto XTilde = ConvStrideW / GcdStrideDilationW;
695
696 const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
697 const auto YDot = math::integer_divide_ceil(Y, YTilde);
698 const auto XDot = math::integer_divide_ceil(X, XTilde);
699
700 const auto DTilde =
701 Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
702 const auto HTilde =
703 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
704 const auto WTilde =
705 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
706
707 // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input tensor
708 const auto IDTildeSliceBegin = math::integer_divide_floor(
709 math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
710 const auto IHTildeSliceBegin = math::integer_divide_floor(
711 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
712 const auto IWTildeSliceBegin = math::integer_divide_floor(
713 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
714
715 const auto IDTildeSliceEnd = math::min(
716 DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
717 const auto IHTildeSliceEnd = math::min(
718 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
719 const auto IWTildeSliceEnd = math::min(
720 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
721
722 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
723 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
724 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
725
726 // GemmK is different for each GEMM
    // The Dot slices depend on (i_ztilde, i_ytilde, i_xtilde), which is why the value differs per GEMM.
727 const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
728 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
729 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
730
731 // A: output tensor
732 const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
733 out_n_do_ho_wo_k_grid_desc,
743
    // Embed (Dot, Tilde) onto each spatial axis with coefficients (-dilation / gcd, 1).
744 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
746 out_n_dop_hop_wop_k_grid_desc,
749 make_embed_transform(make_tuple(ZDot, DTilde),
750 make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
751 make_embed_transform(make_tuple(YDot, HTilde),
752 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
753 make_embed_transform(make_tuple(XDot, WTilde),
754 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
762 Sequence<7>{}));
763
    // Slice every Dot axis to its DotSlice length and every Tilde axis to its non-padding window.
764 const auto
765 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
767 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
769 make_slice_transform(ZDot, I0, ZDotSlice),
770 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
771 make_slice_transform(YDot, I0, YDotSlice),
772 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
773 make_slice_transform(XDot, I0, XDotSlice),
774 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
777 Sequence<1>{},
778 Sequence<2>{},
779 Sequence<3>{},
780 Sequence<4>{},
781 Sequence<5>{},
782 Sequence<6>{},
783 Sequence<7>{}),
785 Sequence<1>{},
786 Sequence<2>{},
787 Sequence<3>{},
788 Sequence<4>{},
789 Sequence<5>{},
790 Sequence<6>{},
791 Sequence<7, 8>{}));
792
    // Merge (ZDotSlice, YDotSlice, XDotSlice, K0) into one dimension and
    // (N, DTildeSlice, HTildeSlice, WTildeSlice) into another.
793 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
794 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
796 make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
797 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
801
802 // B: weight tensor
803 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
805 wei_k_z_y_x_c_grid_desc,
808 make_embed_transform(make_tuple(ZDot, ZTilde),
809 make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
810 make_embed_transform(make_tuple(YDot, YTilde),
811 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
812 make_embed_transform(make_tuple(XDot, XTilde),
813 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
821 Sequence<7>{}));
822
    // The Tilde axes are frozen at (i_ztilde, i_ytilde, i_xtilde); they map to empty
    // Sequence<>{} entries in the upper-dimension list, i.e. they do not appear in the result.
823 const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
824 transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
826 make_slice_transform(ZDot, I0, ZDotSlice),
827 make_slice_transform(YDot, I0, YDotSlice),
828 make_slice_transform(XDot, I0, XDotSlice),
829 make_freeze_transform(i_ztilde),
830 make_freeze_transform(i_ytilde),
831 make_freeze_transform(i_xtilde),
834 Sequence<1>{},
835 Sequence<3>{},
836 Sequence<5>{},
837 Sequence<2>{},
838 Sequence<4>{},
839 Sequence<6>{},
840 Sequence<7>{}),
842 Sequence<2>{},
843 Sequence<3>{},
844 Sequence<4>{},
845 Sequence<>{},
846 Sequence<>{},
847 Sequence<>{},
848 Sequence<5>{}));
849
850 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
851 wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
852 make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
857
858 // C: input tensor
859 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
860 in_n_di_hi_wi_c_grid_desc,
862 make_pad_transform(Di, InLeftPadD, InRightPadD),
863 make_pad_transform(Hi, InLeftPadH, InRightPadH),
864 make_pad_transform(Wi, InLeftPadW, InRightPadW),
870
    // Embed (Tilde_filter, Tilde_spatial) onto each padded axis with coefficients (dilation, stride).
871 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
873 in_n_dip_hip_wip_c_grid_desc,
875 make_embed_transform(make_tuple(ZTilde, DTilde),
876 make_tuple(ConvDilationD, ConvStrideD)),
877 make_embed_transform(make_tuple(YTilde, HTilde),
878 make_tuple(ConvDilationH, ConvStrideH)),
879 make_embed_transform(make_tuple(XTilde, WTilde),
880 make_tuple(ConvDilationW, ConvStrideW)),
888 Sequence<7>{}));
889
890 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
892 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
894 make_freeze_transform(i_ztilde),
895 make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
896 make_freeze_transform(i_ytilde),
897 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
898 make_freeze_transform(i_xtilde),
899 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
902 Sequence<1>{},
903 Sequence<2>{},
904 Sequence<3>{},
905 Sequence<4>{},
906 Sequence<5>{},
907 Sequence<6>{},
908 Sequence<7>{}),
910 Sequence<>{},
911 Sequence<1>{},
912 Sequence<>{},
913 Sequence<2>{},
914 Sequence<>{},
915 Sequence<3>{},
916 Sequence<4>{}));
917
918 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
919 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
921 make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
925
926 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
927 wei_gemmk0_gemmn_gemmk1_grid_desc,
928 in_gemmm_gemmn_grid_desc);
929 }
930
931 } // function end
932
// Zero-argument overload enabled for NDim == 1.
// The visible arguments are placeholder values: sizes of 1 everywhere and a tilde index of 0.
// NOTE(review): the line carrying the call/return (original line 936) is missing from
// this listing; presumably the placeholders are forwarded to the 1-D descriptor builder
// so that its return type can be deduced - confirm against the full source.
933 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
934 static auto GetABCGridDesc()
935 {
937 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
938 }
939
// Zero-argument overload enabled for NDim == 2.
// The visible arguments are placeholder values: sizes of 1 everywhere (two entries per
// spatial vector) and tilde indices {0, 0}.
// NOTE(review): the line carrying the call/return (original line 943) is missing from
// this listing; presumably the placeholders are forwarded to the 2-D descriptor builder
// so that its return type can be deduced - confirm against the full source.
940 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
941 static auto GetABCGridDesc()
942 {
944 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
945 }
946
// Zero-argument overload enabled for NDim == 3.
// The visible arguments are placeholder values: sizes of 1 everywhere (three entries per
// spatial vector) and tilde indices {0, 0, 0}.
// NOTE(review): the line carrying the call/return and the first argument (original
// line 950) is missing from this listing; presumably the placeholders are forwarded to
// the 3-D descriptor builder so that its return type can be deduced - confirm against
// the full source.
947 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
948 static auto GetABCGridDesc()
949 {
951 1,
952 1,
953 {1, 1, 1},
954 {1, 1, 1},
955 {1, 1, 1},
956 {1, 1, 1},
957 {1, 1, 1},
958 {1, 1, 1},
959 {1, 1, 1},
960 {0, 0, 0});
961 }
962
964
968
969 // GridwiseGemm
// NOTE(review): the start of this statement (original lines 970-971, presumably the
// `using GridwiseGemm = ...<` part) and original lines 975-978 are missing from this
// listing. What remains is the tail of a template-argument list: the GEMM operand types
// (A, accumulator, C) followed by this class's tile-size and block-transfer template
// parameters, passed through unchanged and in declaration order.
972 ADataType,
973 AccDataType,
974 CDataType,
979 MPerBlock,
980 NPerBlock,
981 K0PerBlock,
982 K1,
983 M1PerThread,
984 N1PerThread,
985 KPerThread,
986 M1N1ThreadClusterM1Xs,
987 M1N1ThreadClusterN1Xs,
988 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
989 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
990 ABlockTransferThreadClusterArrangeOrder,
991 ABlockTransferSrcAccessOrder,
992 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
993 ABlockTransferSrcVectorTensorContiguousDimOrder,
994 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
995 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
996 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
997 BBlockTransferThreadClusterArrangeOrder,
998 BBlockTransferSrcAccessOrder,
999 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
1000 BBlockTransferSrcVectorTensorContiguousDimOrder,
1001 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
1002 CThreadTransferSrcDstAccessOrder,
1003 CThreadTransferSrcDstVectorDim,
1004 CThreadTransferDstScalarPerVector>;
1005
1014 // Argument
1015 struct Argument : public BaseArgument
1016 {
// Captures the convolution problem and maps its tensors onto GEMM operands:
//   p_out_grid -> A (const pointer), p_wei_grid -> B (const pointer),
//   p_in_grid  -> C (non-const pointer).
// The element-wise operators follow the same mapping (out -> a, wei -> b, in -> c).
// N, K, C and all per-dimension vectors are copied into members for later use.
// NOTE(review): the statement inside the constructor body (original line 1050) is
// missing from this listing.
1017 Argument(InDataType* p_in_grid,
1018 const WeiDataType* p_wei_grid,
1019 const OutDataType* p_out_grid,
1020 ck::index_t N,
1021 ck::index_t K,
1022 ck::index_t C,
1023 std::vector<ck::index_t> input_spatial_lengths,
1024 std::vector<ck::index_t> filter_spatial_lengths,
1025 std::vector<ck::index_t> output_spatial_lengths,
1026 std::vector<ck::index_t> conv_filter_strides,
1027 std::vector<ck::index_t> conv_filter_dilations,
1028 std::vector<ck::index_t> input_left_pads,
1029 std::vector<ck::index_t> input_right_pads,
1030 InElementwiseOperation in_element_op,
1031 WeiElementwiseOperation wei_element_op,
1032 OutElementwiseOperation out_element_op)
1033 : p_a_grid_{p_out_grid},
1034 p_b_grid_{p_wei_grid},
1035 p_c_grid_{p_in_grid},
1036 a_element_op_{out_element_op},
1037 b_element_op_{wei_element_op},
1038 c_element_op_{in_element_op},
1039 Conv_N_{N},
1040 Conv_K_{K},
1041 Conv_C_{C},
1042 input_spatial_lengths_{input_spatial_lengths},
1043 filter_spatial_lengths_{filter_spatial_lengths},
1044 output_spatial_lengths_{output_spatial_lengths},
1045 conv_filter_strides_{conv_filter_strides},
1046 conv_filter_dilations_{conv_filter_dilations},
1047 input_left_pads_{input_left_pads},
1048 input_right_pads_{input_right_pads}
1049 {
1051 }
1052
// 1-D variant (NDim == 1): creates one set of A/B/C descriptors per i_xtilde in
// [0, XTilde) and appends each set to the three *_container_ vectors.
// Iterations whose XDotSlice is not positive are skipped.
// NOTE(review): the signature line (original line 1054), most arguments of the
// descriptor-builder call, and the body of the CheckValidity branch are missing from
// this listing.
1053 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
1055 {
1056 const index_t ConvStrideW = conv_filter_strides_[0];
1057 const index_t ConvDilationW = conv_filter_dilations_[0];
1058 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
    // Number of GEMM candidates along W: stride / gcd(stride, dilation).
1059 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1060
1061 const index_t X = filter_spatial_lengths_[0];
1062
1063 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1064 {
1065 // check slice is valid
1066 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1067 if(XDotSlice <= 0)
1068 {
1069 continue;
1070 }
1071
1072 const auto descs =
1074 Conv_N_,
1075 Conv_K_,
1076 Conv_C_,
1084 {i_xtilde});
    // descs is indexed as (A, B, C).
1085 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1086 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1087 c_grid_desc_m_n_container_.push_back(descs[I2]);
1088
    // Extra work is done only for descriptor sets the gridwise GEMM accepts
    // (branch body not visible in this listing).
1089 if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1090 {
1097
1100 }
1101 }
1102 }
// 2-D variant (NDim == 2): creates one set of A/B/C descriptors per (i_ytilde, i_xtilde)
// pair in [0, YTilde) x [0, XTilde) and appends each set to the three *_container_
// vectors. Pairs whose YDotSlice * XDotSlice is not positive are skipped.
// NOTE(review): the signature line (original line 1104), most arguments of the
// descriptor-builder call, and the body of the CheckValidity branch are missing from
// this listing.
1103 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
1105 {
1106 const index_t ConvStrideH = conv_filter_strides_[0];
1107 const index_t ConvStrideW = conv_filter_strides_[1];
1108
1109 const index_t ConvDilationH = conv_filter_dilations_[0];
1110 const index_t ConvDilationW = conv_filter_dilations_[1];
1111
1112 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1113 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1114
    // Number of GEMM candidates per dimension: stride / gcd(stride, dilation).
1115 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1116 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1117
1118 const index_t Y = filter_spatial_lengths_[0];
1119 const index_t X = filter_spatial_lengths_[1];
1120 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1121 {
1122 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1123 {
1124 // check slice is valid
1125 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1126 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1127 if(YDotSlice * XDotSlice <= 0)
1128 {
1129 continue;
1130 }
1131
1132 const auto descs =
1134 Conv_N_,
1135 Conv_K_,
1136 Conv_C_,
1144 {i_ytilde, i_xtilde});
    // descs is indexed as (A, B, C).
1145 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1146 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1147 c_grid_desc_m_n_container_.push_back(descs[I2]);
1148
    // Extra work is done only for descriptor sets the gridwise GEMM accepts
    // (branch body not visible in this listing).
1149 if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1150 {
1157
1160 }
1161 }
1162 }
1163 }
// 3-D variant (NDim == 3): creates one set of A/B/C descriptors per
// (i_ztilde, i_ytilde, i_xtilde) triple in [0, ZTilde) x [0, YTilde) x [0, XTilde) and
// appends each set to the three *_container_ vectors. Triples whose
// ZDotSlice * YDotSlice * XDotSlice is not positive are skipped.
// NOTE(review): the signature line (original line 1165), most arguments of the
// descriptor-builder call, and the body of the CheckValidity branch are missing from
// this listing.
1164 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
1166 {
1167 const index_t ConvStrideD = conv_filter_strides_[0];
1168 const index_t ConvStrideH = conv_filter_strides_[1];
1169 const index_t ConvStrideW = conv_filter_strides_[2];
1170
1171 const index_t ConvDilationD = conv_filter_dilations_[0];
1172 const index_t ConvDilationH = conv_filter_dilations_[1];
1173 const index_t ConvDilationW = conv_filter_dilations_[2];
1174
1175 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
1176 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1177 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1178
    // Number of GEMM candidates per dimension: stride / gcd(stride, dilation).
1179 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1180 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1181 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1182
1183 const index_t Z = filter_spatial_lengths_[0];
1184 const index_t Y = filter_spatial_lengths_[1];
1185 const index_t X = filter_spatial_lengths_[2];
1186 for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1187 {
1188 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1189 {
1190 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1191 {
1192 // check slice is valid
1193 const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
1194 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1195 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1196 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1197 {
1198 continue;
1199 }
1200
1201 const auto descs =
1203 Conv_N_,
1204 Conv_K_,
1205 Conv_C_,
1213 {i_ztilde, i_ytilde, i_xtilde});
    // descs is indexed as (A, B, C).
1214 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1215 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1216 c_grid_desc_m_n_container_.push_back(descs[I2]);
1217
    // Extra work is done only for descriptor sets the gridwise GEMM accepts
    // (branch body not visible in this listing).
1218 if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1219 {
1226
1229 }
1230 }
1231 }
1232 }
1233 }
1234
1238 std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
1239 std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
1240 std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
1241
1242 std::vector<AGridDesc_K0_M0_M1_K1> a_grid_desc_k0_m0_m1_k1_container_;
1243 std::vector<BGridDesc_K0_N0_N1_K1> b_grid_desc_k0_n0_n1_k1_container_;
1244 std::vector<CGridDesc_M0_M10_M11_N0_N10_N11> c_grid_desc_m0_m10_m11_n0_n10_n11_container_;
1245
1246 std::vector<DefaultBlock2CTileMap> block_2_ctile_map_container_;
1247
1248 // element-wise op
1249 OutElementwiseOperation a_element_op_;
1250 WeiElementwiseOperation b_element_op_;
1251 InElementwiseOperation c_element_op_;
1252 // for checking IsSupportedArgument()
1256
1257 std::vector<ck::index_t> input_spatial_lengths_;
1258 std::vector<ck::index_t> filter_spatial_lengths_;
1259 std::vector<ck::index_t> output_spatial_lengths_;
1260 std::vector<ck::index_t> conv_filter_strides_;
1261 std::vector<ck::index_t> conv_filter_dilations_;
1262 std::vector<ck::index_t> input_left_pads_;
1263 std::vector<ck::index_t> input_right_pads_;
1264 };
1265
1266 // Invoker
1267 struct Invoker : public BaseInvoker
1268 {
1270
1271 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1272 {
1273 float ave_time = 0;
1274 for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1275 {
1276 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1277 {
1278 std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
1279 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
1280 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
1281 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
1282 << std::endl;
1283
1284 std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
1285 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
1286 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
1287 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
1288 << std::endl;
1289
1290 std::cout << "arg.c_grid_desc_m_n_container_{ "
1291 << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
1292 << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
1293 << std::endl;
1294
1295 std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
1297 << ", "
1299 << ", "
1301 << ", "
1303 << ", "
1305 << ", "
1307 << " ) " << std::endl;
1308 }
1309
1313 {
1314 throw std::runtime_error(
1315 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
1316 }
1317
1318 const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
1320
1321 auto launch_kernel = [&](auto has_main_k_block_loop,
1322 auto has_double_tail_k_block_loop) {
1323 constexpr bool has_main_loop = has_main_k_block_loop.value;
1324 constexpr bool has_double_loop = has_double_tail_k_block_loop;
1325
1326 const auto kernel = kernel_gemm_dl_v1r3<
1327 GridwiseGemm,
1328 ADataType, // TODO: distiguish A/B datatype
1329 CDataType,
1334 has_main_loop,
1335 has_double_loop>;
1336
1337 ave_time +=
1338 launch_and_time_kernel(stream_config,
1339 kernel,
1340 dim3(grid_size),
1341 dim3(BlockSize),
1342 0,
1343 arg.p_a_grid_,
1344 arg.p_b_grid_,
1345 arg.p_c_grid_,
1350 };
1351
1352 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0);
1353 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
1354 const bool has_double_tail_k_block_loop =
1356
1357 if(has_main_k_block_loop && has_double_tail_k_block_loop)
1358 {
1360 }
1361 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1362 {
1363 launch_kernel(integral_constant<bool, true>{},
1364 integral_constant<bool, false>{});
1365 }
1366 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1367 {
1368 launch_kernel(integral_constant<bool, false>{},
1369 integral_constant<bool, true>{});
1370 }
1371 else
1372 {
1373 launch_kernel(integral_constant<bool, false>{},
1374 integral_constant<bool, false>{});
1375 }
1376 }
1377 return ave_time;
1378 }
1379
1380 float Run(const BaseArgument* p_arg,
1381 const StreamConfig& stream_config = StreamConfig{}) override
1382 {
1383 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1384 }
1385 };
1386
// Compile-time validity hook for the template configuration.
// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
1392
1393 static bool IsSupportedArgument(const Argument& arg)
1394 {
1395 // check device
1396 if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
1398 {
1399 return false;
1400 }
1401
1402 if constexpr(ConvBackwardDataSpecialization ==
1404 {
1405 // check if it's 1x1, stride=1 pad = 0 conv
1406 for(int i = 0; i < NDimSpatial; i++)
1407 {
1408 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1409 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1410 {
1411 return false;
1412 }
1413 }
1414 }
1415
1416 // matrix A
1417 {
1418 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1419 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
1420 {
1421 return false;
1422 }
1423 if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
1424 {
1425 return false;
1426 }
1427
1428 const index_t K = arg.Conv_K_;
1429
1430 if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
1431 {
1432 return false;
1433 }
1434 }
1435
1436 // matrix B
1437 {
1438 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1439 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1440 if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
1441 {
1442 return false;
1443 }
1444 if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
1445 srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
1446 {
1447 return false;
1448 }
1449
1450 const index_t C = arg.Conv_K_;
1451
1452 if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
1453 {
1454 return false;
1455 }
1456 }
1457 // vector store C matrix into global memory
1458 if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1459 {
1460 std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1461 << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1462 return false;
1463 }
1464
1465 // Gridwise GEMM size
1466 for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1467 {
1471 {
1472 return false;
1473 }
1474 }
1475 return true;
1476 }
1477
1478 bool IsSupportedArgument(const BaseArgument* p_arg) override
1479 {
1480 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1481 }
1482
1483 static auto MakeArgument(InDataType* p_in_grid,
1484 const WeiDataType* p_wei_grid,
1485 const OutDataType* p_out_grid,
1486 ck::index_t N,
1487 ck::index_t K,
1488 ck::index_t C,
1489 std::vector<ck::index_t> input_spatial_lengths,
1490 std::vector<ck::index_t> filter_spatial_lengths,
1491 std::vector<ck::index_t> output_spatial_lengths,
1492 std::vector<ck::index_t> conv_filter_strides,
1493 std::vector<ck::index_t> conv_filter_dilations,
1494 std::vector<ck::index_t> input_left_pads,
1495 std::vector<ck::index_t> input_right_pads,
1496 InElementwiseOperation in_element_op,
1497 WeiElementwiseOperation wei_element_op,
1498 OutElementwiseOperation out_element_op)
1499 {
1500 return Argument{p_in_grid,
1501 p_wei_grid,
1502 p_out_grid,
1503 N,
1504 K,
1505 C,
1506 input_spatial_lengths,
1507 filter_spatial_lengths,
1508 output_spatial_lengths,
1509 conv_filter_strides,
1510 conv_filter_dilations,
1511 input_left_pads,
1512 input_right_pads,
1513 in_element_op,
1514 wei_element_op,
1515 out_element_op};
1516 }
1517
1518 static auto MakeInvoker() { return Invoker{}; }
1519
1520 std::unique_ptr<BaseArgument>
1521 MakeArgumentPointer(void* p_in_grid,
1522 const void* p_wei_grid,
1523 const void* p_out_grid,
1524 ck::index_t N,
1525 ck::index_t K,
1526 ck::index_t C,
1527 std::vector<ck::index_t> input_spatial_lengths,
1528 std::vector<ck::index_t> filter_spatial_lengths,
1529 std::vector<ck::index_t> output_spatial_lengths,
1530 std::vector<ck::index_t> conv_filter_strides,
1531 std::vector<ck::index_t> conv_filter_dilations,
1532 std::vector<ck::index_t> input_left_pads,
1533 std::vector<ck::index_t> input_right_pads,
1534 InElementwiseOperation in_element_op,
1535 WeiElementwiseOperation wei_element_op,
1536 OutElementwiseOperation out_element_op) override
1537 {
1538 return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
1539 static_cast<const WeiDataType*>(p_wei_grid),
1540 static_cast<const OutDataType*>(p_out_grid),
1541 N,
1542 K,
1543 C,
1544 input_spatial_lengths,
1545 filter_spatial_lengths,
1546 output_spatial_lengths,
1547 conv_filter_strides,
1548 conv_filter_dilations,
1549 input_left_pads,
1550 input_right_pads,
1551 in_element_op,
1552 wei_element_op,
1553 out_element_op);
1554 }
1555
1556 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1557 {
1558 return std::make_unique<Invoker>(Invoker{});
1559 }
1560
1561 std::string GetTypeString() const override
1562 {
1563 auto str = std::stringstream();
1564
1565 // clang-format off
1566 str << "DeviceConvNdBwdDataNwcKxcNwk_Dl"
1567 << "<"
1568 << BlockSize << ", "
1569 << MPerBlock << ", "
1570 << NPerBlock << ", "
1571 << K0PerBlock << ", "
1572 << K1
1573 << ">";
1574 if constexpr(ConvBackwardDataSpecialization ==
1576
1577 str<< " Filter1x1Stride1Pad0";
1578 }
1579
1580
1581 return str.str();
1582 }
1583};
1584
1585} // namespace device
1586} // namespace tensor_operation
1587} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_conv_bwd_data.hpp:25
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1016
const BDataType * p_b_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1236
std::vector< CGridDesc_M0_M10_M11_N0_N10_N11 > c_grid_desc_m0_m10_m11_n0_n10_n11_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1244
index_t Conv_N_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1253
std::vector< BGridDesc_K0_N0_N1_K1 > b_grid_desc_k0_n0_n1_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1243
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1240
std::vector< ck::index_t > filter_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1258
index_t Conv_C_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1255
InElementwiseOperation c_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1251
std::vector< ck::index_t > output_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1259
std::vector< ck::index_t > input_left_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1262
std::vector< ck::index_t > conv_filter_strides_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1260
std::vector< ck::index_t > input_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1257
WeiElementwiseOperation b_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1250
index_t Conv_K_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1254
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1239
CDataType * p_c_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1237
std::vector< ck::index_t > input_right_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1263
void CreateABCDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1054
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1238
const ADataType * p_a_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1235
std::vector< DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1246
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1017
OutElementwiseOperation a_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1249
std::vector< ck::index_t > conv_filter_dilations_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1261
std::vector< AGridDesc_K0_M0_M1_K1 > a_grid_desc_k0_m0_m1_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1242
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1268
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1380
DeviceOp::Argument Argument
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1269
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1271
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:82
InDataType CDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:87
InDataType ABDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:90
static constexpr auto I3
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:95
static constexpr bool IsValidCompilationParameter()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1387
DeviceConvNdBwdDataNwcKxcNwk_Dl DeviceOp
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:83
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:963
static constexpr auto I7
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:99
static constexpr auto I5
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:97
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1556
std::string GetTypeString() const override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1561
static bool IsSupportedArgument(const Argument &arg)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1393
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:103
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1483
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:966
static constexpr auto I2
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1521
static auto MakeInvoker()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1518
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1012
static constexpr auto I4
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:96
OutDataType ADataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:85
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:967
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1008
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:965
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1478
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1006
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1010
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:970
static constexpr auto I6
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:98
static constexpr auto I0
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:92
static constexpr auto I1
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:93
WeiDataType BDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:86
static auto GetABCGridDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:934
#define CK_ENV(name)
Definition utility/env.hpp:129