31#ifdef CK_EXPERIMENTAL_BUILDER
32#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
66template <
typename GridwiseGemm,
71 typename AElementwiseOperation,
72 typename BElementwiseOperation,
73 typename CDEElementwiseOperation,
74 typename AGridDesc_AK0_M_AK1,
75 typename BGridDesc_BK0_N_BK1,
76 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
77 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
78 typename Block2ETileMap,
79 typename ComputePtrOffsetOfG,
80 typename ComputePtrOffsetOfN,
81 bool HasMainKBlockLoop,
86#if CK_USE_LAUNCH_BOUNDS
89 kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
93 EDataType* __restrict__ p_e_grid,
94 AElementwiseOperation a_element_op,
95 BElementwiseOperation b_element_op,
96 CDEElementwiseOperation cde_element_op,
97 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
98 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
99 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
100 ds_grid_desc_mblock_mperblock_nblock_nperblock,
101 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
102 e_grid_desc_mblock_mperblock_nblock_nperblock_,
103 const Block2ETileMap block_2_ctile_map,
104 const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
105 const ComputePtrOffsetOfN compute_ptr_offset_of_n)
107#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
108 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
111 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
112 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
116 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
117 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
122 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
124 DsPointer p_ds_grid_grp;
126 static constexpr index_t NumDTensor =
127 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
129 static_for<0, NumDTensor, 1>{}(
130 [&](
auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; });
132 if constexpr(isMultiA || isMultiB)
134 AsPointer p_as_grid_grp;
135 BsPointer p_bs_grid_grp;
137 const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
142 if constexpr(isMultiA)
144 const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
146 static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
147 static_for<0, NumATensor, 1>{}([&](
auto i) {
148 p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
153 const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
154 static_for<0, 1, 1>{}([&](
auto i) {
155 p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset;
159 const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
161 static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
162 static_for<0, NumBTensor, 1>{}(
163 [&](
auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });
165 GridwiseGemm::template Run<HasMainKBlockLoop>(
169 p_e_grid + e_group_offset + e_n_offset,
176 ds_grid_desc_mblock_mperblock_nblock_nperblock,
177 e_grid_desc_mblock_mperblock_nblock_nperblock_,
197 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
198 p_as_grid + a_group_offset + a_n_offset,
199 p_bs_grid + b_group_offset + b_n_offset,
201 p_e_grid + e_group_offset + e_n_offset,
208 ds_grid_desc_mblock_mperblock_nblock_nperblock,
209 e_grid_desc_mblock_mperblock_nblock_nperblock_,
218 ignore = a_grid_desc_k0_m_k1;
219 ignore = b_grid_desc_k0_n_k1;
220 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
221 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
225 ignore = compute_ptr_offset_of_groups;
226 ignore = compute_ptr_offset_of_n;
227 ignore = block_2_ctile_map;
232#ifdef CK_CODE_GEN_RTC
234using is_tuple =
decltype(ck::declval<T&>().IsTuple());
237using is_tuple =
decltype(std::declval<T&>().IsTuple());
263 typename AccDataType,
264 typename CShuffleDataType,
267 typename AElementwiseOperation,
268 typename BElementwiseOperation,
269 typename CDEElementwiseOperation,
283 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
284 typename ABlockTransferThreadClusterArrangeOrder,
285 typename ABlockTransferSrcAccessOrder,
286 index_t ABlockTransferSrcVectorDim,
287 index_t ABlockTransferSrcScalarPerVector,
288 index_t ABlockTransferDstScalarPerVector_AK1,
290 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
291 typename BBlockTransferThreadClusterArrangeOrder,
292 typename BBlockTransferSrcAccessOrder,
293 index_t BBlockTransferSrcVectorDim,
294 index_t BBlockTransferSrcScalarPerVector,
295 index_t BBlockTransferDstScalarPerVector_BK1,
297 index_t CShuffleMXdlPerWavePerShuffle,
298 index_t CShuffleNXdlPerWavePerShuffle,
299 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
300 index_t CDEBlockTransferScalarPerVector_NPerBlock,
301 typename AComputeDataType =
302 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
307 typename BComputeDataType = AComputeDataType,
320 AElementwiseOperation,
321 BElementwiseOperation,
322 CDEElementwiseOperation,
331 static_assert(NumGroupsToMerge >= 1);
359 (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
372 ConvForwardSpecialization,
381 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
394 template <
typename ALay>
398 using Layout = std::conditional_t<
405 const auto in_gemmmraw_gemmkraw_desc =
406 conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
408 const auto in_gemmm_gemmk_desc =
411 return in_gemmm_gemmk_desc;
414 template <
typename BLay>
418 using Layout = std::conditional_t<
425 const auto wei_gemmnraw_gemmkraw_desc =
426 conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
428 const auto wei_gemmn_gemmk_desc =
429 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
431 return wei_gemmn_gemmk_desc;
434 template <
typename ELay>
438 using Layout = std::conditional_t<
445 const auto out_gemmmraw_gemmnraw_desc =
446 conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
449 constexpr auto matrix_padder_trans =
455 return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
485 using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
486 using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
488#define GridwiseGemmMultiABDTemplateParameters \
489 GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
490 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
491 InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
492 KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
493 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
494 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
495 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
496 ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
497 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
498 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
499 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
500 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
501 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
502 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
505#define GridwiseGemmTemplateParameters \
506 GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
507 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
508 NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
509 NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
510 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
511 ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
512 ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
513 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
514 BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
515 BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
516 BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
517 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
518 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
519 BComputeDataType, DoElementwiseBeforeCShuffle
521#define GridwiseGemmCTransposeTemplateParameters \
522 GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
523 EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
524 NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
525 MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
526 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
527 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
528 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
529 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
530 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
531 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
532 ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
533 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
534 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
535 BComputeDataType, DoElementwiseBeforeCShuffle
538 template <index_t NXdlPerWave_>
541 template <index_t NXdlPerWave_>
544 template <index_t NXdlPerWave_>
567 std::conditional_t<isMultiA, std::array<const void*, NumATensor>&,
const void*>;
569 std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&,
const void*>;
573 decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>;
575 decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>;
579 remove_cvref_t<
decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
582 remove_cvref_t<
decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
585 decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
588 decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
593 remove_cvref_t<
decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(
599 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
602 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
606 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
609 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
675 template <
typename Gr
idwiseGemm,
typename Gr
idwiseGemmCTranspose>
681 const auto as_grid_desc_ak0_m_ak1 =
683 const auto bs_grid_desc_bk0_n_bk1 =
686 if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
687 bs_grid_desc_bk0_n_bk1,
693 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
697 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
733 const std::array<const void*, NumDTensor>& p_ds,
735 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
736 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
737 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
738 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
739 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
740 ds_g_n_k_wos_lengths,
741 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
742 ds_g_n_k_wos_strides,
743 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
744 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
745 const std::array<index_t, NDimSpatial>& conv_filter_strides,
746 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
747 const std::array<index_t, NDimSpatial>& input_left_pads,
748 const std::array<index_t, NDimSpatial>& input_right_pads,
749 const AElementwiseOperation& a_element_op,
750 const BElementwiseOperation& b_element_op,
751 const CDEElementwiseOperation& cde_element_op)
759 a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
760 : a_g_n_c_wis_strides},
764 b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
765 : b_g_k_c_xs_strides},
771 e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
772 : e_g_n_k_wos_strides},
827 p_as_grid_(i) =
static_cast<const DataType*
>(p_as[i.value]);
837 p_as_grid_(i) =
static_cast<const DataType*
>(p_as);
854 p_bs_grid_(i) =
static_cast<const DataType*
>(p_bs[i.value]);
859 p_bs_grid_(i) =
static_cast<const DataType*
>(p_bs);
883 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds[i]);
929 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
932 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
936 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
939 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
943 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
946 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
993 return sizeof(EDataType) * e_accum;
1016 [&](
auto i) { std::cout <<
"Ds[M, N]: " <<
ds_grid_desc_m_n_[i] << std::endl; });
1076 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
1091 template <
typename Gr
idwiseGemm,
typename Gr
idwiseGemmCTranspose>
1094 if(stream_config.log_level_ > 0)
1099 const index_t num_workgroups_per_Conv_N =
1104 const index_t gdz = num_workgroups_per_Conv_N;
1109 auto launch_kernel = [&](
auto has_main_k_block_loop) {
1110 constexpr bool has_main_loop = has_main_k_block_loop.value;
1120 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1124 typename GridwiseGemm::DsGridPointer,
1126 AElementwiseOperation,
1127 BElementwiseOperation,
1128 CDEElementwiseOperation,
1129 decltype(as_grid_desc_ak0_m_ak1),
1130 decltype(bs_grid_desc_bk0_n_bk1),
1134 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1135 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1144 dim3(gdx, gdy, gdz),
1154 as_grid_desc_ak0_m_ak1,
1155 bs_grid_desc_bk0_n_bk1,
1193 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1194 GridwiseGemmCTranspose,
1197 typename GridwiseGemm::DsGridPointer,
1199 BElementwiseOperation,
1200 AElementwiseOperation,
1201 CDEElementwiseOperation,
1207 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1208 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1216 dim3(gdx, gdy, gdz),
1236 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1240 typename GridwiseGemm::DsGridPointer,
1242 AElementwiseOperation,
1243 BElementwiseOperation,
1244 CDEElementwiseOperation,
1250 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1251 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1260 dim3(gdx, gdy, gdz),
1281 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
1291 template <
typename Gr
idwiseGemm,
typename Gr
idwiseGemmCTranspose>
1294 float avg_time = 0.f;
1327 dim3(a_grid_size + b_grid_size),
1344 avg_time += RunGemm<GridwiseGemm, GridwiseGemmCTranspose>(arg, stream_config);
1352 const EDataType* p_e_in_grid =
1357 EDataType* p_e_out_grid = arg.
p_e_grid_;
1389 return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
1400 return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
1411 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1419 const index_t G = arg.b_g_k_c_xs_lengths_[
I0];
1420 const index_t K = arg.b_g_k_c_xs_lengths_[
I1];
1421 const index_t C = arg.b_g_k_c_xs_lengths_[
I2];
1423 arg.a_g_n_c_wis_lengths_.begin() +
I3, NDimSpatial, 1, std::multiplies<>());
1441 if constexpr(ConvForwardSpecialization ==
1445 for(
index_t i = 0; i < NDimSpatial; ++i)
1447 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1448 const index_t ConvStride = arg.conv_filter_strides_[i];
1452 if(!(SpatialDim == 1 && ConvStride == 1 &&
LeftPad == 0 &&
RightPad == 0))
1458 else if constexpr(ConvForwardSpecialization ==
1462 for(
index_t i = 0; i < NDimSpatial; ++i)
1464 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1480 for(
index_t i = 0; i < NDimSpatial; ++i)
1482 const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i +
I3];
1484 if(filter_spatial_dim !=
I3)
1491 if constexpr(NumGroupsToMerge > 1)
1497 if(G % NumGroupsToMerge != 0)
1520 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
1523 if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
1527 G % ABlockTransferSrcScalarPerVector == 0))
1536 static_assert(NumGroupsToMerge == 1);
1538 if constexpr(ABlockTransferSrcScalarPerVector != 1)
1540 if(ABlockTransferSrcVectorDim != 1)
1544 if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
1565 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
1587 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1595 if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
1596 arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
1604 for(
index_t d = 0; d < NDimSpatial + 3; d++)
1606 if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
1621 if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1626 if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1632 arg.e_g_n_k_wos_lengths_.begin() +
I3, NDimSpatial, 1, std::multiplies<>());
1634 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1639 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1644 if(!arg.p_workspace_)
1648 std::cout <<
"Warning: Workspace for "
1649 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not "
1650 "allocated, use SetWorkSpacePointer."
1657 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() *
sizeof(ADataType) <= TwoGB &&
1658 arg.e_in_transpose_desc_.GetElementSpaceSize() *
sizeof(EDataType) <= TwoGB))
1679 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1687 arg.e_g_n_k_wos_lengths_.begin() +
I3, NDimSpatial, 1, std::multiplies<>());
1689 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1710 std::cout <<
"ComputeDataType for A and B should be same while using TF32"
1728 return GridwiseGemm64::CheckValidity(as_grid_desc_ak0_m_ak1,
1729 bs_grid_desc_bk0_n_bk1,
1730 arg.ds_grid_desc_m_n_,
1731 arg.e_grid_desc_m_n_,
1732 arg.block_2_etile_map_);
1738 return GridwiseGemmCTranspose64::CheckValidity(arg.b_grid_desc_n_k_,
1739 arg.a_grid_desc_m_k_,
1740 arg.ds_grid_desc_m_n_,
1741 arg.e_grid_desc_m_n_,
1742 arg.block_2_etile_map_);
1746 return GridwiseGemmCTranspose64::CheckValidity(arg.a_grid_desc_m_k_,
1747 arg.b_grid_desc_n_k_,
1748 arg.ds_grid_desc_m_n_,
1749 arg.e_grid_desc_m_n_,
1750 arg.block_2_etile_map_);
1767 return GridwiseGemm32::CheckValidity(as_grid_desc_ak0_m_ak1,
1768 bs_grid_desc_bk0_n_bk1,
1769 arg.ds_grid_desc_m_n_,
1770 arg.e_grid_desc_m_n_,
1771 arg.block_2_etile_map_);
1777 return GridwiseGemmCTranspose32::CheckValidity(arg.b_grid_desc_n_k_,
1778 arg.a_grid_desc_m_k_,
1779 arg.ds_grid_desc_m_n_,
1780 arg.e_grid_desc_m_n_,
1781 arg.block_2_etile_map_);
1785 return GridwiseGemmCTranspose32::CheckValidity(arg.a_grid_desc_m_k_,
1786 arg.b_grid_desc_n_k_,
1787 arg.ds_grid_desc_m_n_,
1788 arg.e_grid_desc_m_n_,
1789 arg.block_2_etile_map_);
1806 const std::array<const void*, NumDTensor>& p_ds,
1808 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1809 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1810 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1811 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1812 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1813 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1814 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1815 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1816 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1817 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1818 const std::array<index_t, NDimSpatial>& input_left_pads,
1819 const std::array<index_t, NDimSpatial>& input_right_pads,
1820 const AElementwiseOperation& a_element_op,
1821 const BElementwiseOperation& b_element_op,
1822 const CDEElementwiseOperation& cde_element_op)
1824 return Argument{p_as,
1828 a_g_n_c_wis_lengths,
1829 a_g_n_c_wis_strides,
1832 ds_g_n_k_wos_lengths,
1833 ds_g_n_k_wos_strides,
1834 e_g_n_k_wos_lengths,
1835 e_g_n_k_wos_strides,
1836 conv_filter_strides,
1837 conv_filter_dilations,
1848 const std::array<const void*, NumDTensor>& p_ds,
1850 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1851 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1852 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1853 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1854 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1855 ds_g_n_k_wos_lengths,
1856 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1857 ds_g_n_k_wos_strides,
1858 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1859 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1860 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1861 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1862 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1863 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1864 const AElementwiseOperation& a_element_op,
1865 const BElementwiseOperation& b_element_op,
1866 const CDEElementwiseOperation& cde_element_op)
1868 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1869 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1870 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1871 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1872 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
1873 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
1874 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1875 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1876 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1877 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1878 std::array<index_t, NDimSpatial> input_left_pads_i32;
1879 std::array<index_t, NDimSpatial> input_right_pads_i32;
1881 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
1882 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
1887 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1888 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1890 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1891 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1892 array_convert(conv_filter_strides_i32, conv_filter_strides);
1893 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1897 return Argument{p_as,
1901 a_g_n_c_wis_lengths_i32,
1902 a_g_n_c_wis_strides_i32,
1903 b_g_k_c_xs_lengths_i32,
1904 b_g_k_c_xs_strides_i32,
1905 ds_g_n_k_wos_lengths_i32,
1906 ds_g_n_k_wos_strides_i32,
1907 e_g_n_k_wos_lengths_i32,
1908 e_g_n_k_wos_strides_i32,
1909 conv_filter_strides_i32,
1910 conv_filter_dilations_i32,
1911 input_left_pads_i32,
1912 input_right_pads_i32,
1923 const std::array<const void*, NumDTensor>& p_ds,
1925 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1926 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1927 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1928 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1929 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
1930 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
1931 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1932 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1933 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1934 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1935 const std::array<index_t, NDimSpatial>& input_left_pads,
1936 const std::array<index_t, NDimSpatial>& input_right_pads,
1937 const AElementwiseOperation& a_element_op,
1938 const BElementwiseOperation& b_element_op,
1939 const CDEElementwiseOperation& cde_element_op)
override
1941 return std::make_unique<Argument>(p_as,
1945 a_g_n_c_wis_lengths,
1946 a_g_n_c_wis_strides,
1949 ds_g_n_k_wos_lengths,
1950 ds_g_n_k_wos_strides,
1951 e_g_n_k_wos_lengths,
1952 e_g_n_k_wos_strides,
1953 conv_filter_strides,
1954 conv_filter_dilations,
1962 std::unique_ptr<BaseArgument>
1965 const std::array<const void*, NumDTensor>& p_ds,
1967 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1968 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1969 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1970 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1971 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1972 ds_g_n_k_wos_lengths,
1973 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
1974 ds_g_n_k_wos_strides,
1975 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1976 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1977 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1978 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1979 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1980 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1981 const AElementwiseOperation& a_element_op,
1982 const BElementwiseOperation& b_element_op,
1983 const CDEElementwiseOperation& cde_element_op)
override
1986 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1987 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1988 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1989 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1990 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
1991 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
1992 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1993 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1994 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1995 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1996 std::array<index_t, NDimSpatial> input_left_pads_i32;
1997 std::array<index_t, NDimSpatial> input_right_pads_i32;
1999 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
2000 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
2005 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
2006 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
2008 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
2009 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
2010 array_convert(conv_filter_strides_i32, conv_filter_strides);
2011 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
2015 return std::make_unique<Argument>(p_as,
2019 a_g_n_c_wis_lengths_i32,
2020 a_g_n_c_wis_strides_i32,
2021 b_g_k_c_xs_lengths_i32,
2022 b_g_k_c_xs_strides_i32,
2023 ds_g_n_k_wos_lengths_i32,
2024 ds_g_n_k_wos_strides_i32,
2025 e_g_n_k_wos_lengths_i32,
2026 e_g_n_k_wos_strides_i32,
2027 conv_filter_strides_i32,
2028 conv_filter_dilations_i32,
2029 input_left_pads_i32,
2030 input_right_pads_i32,
2038 return std::make_unique<Invoker>(Invoker{});
2043 auto str = std::stringstream();
2046 str <<
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
2048 << BlockSize <<
", "
2049 << MPerBlock <<
", "
2050 << NPerBlock <<
", "
2051 << KPerBlock <<
", "
2055 << MXdlPerWave <<
", "
2056 << NXdlPerWave <<
", "
2057 << ABlockTransferSrcScalarPerVector <<
", "
2058 << BBlockTransferSrcScalarPerVector <<
", "
2059 << CDEBlockTransferScalarPerVector_NPerBlock <<
", "
2060 << CShuffleMXdlPerWavePerShuffle <<
", "
2061 << CShuffleNXdlPerWavePerShuffle <<
", "
2069#ifdef CK_EXPERIMENTAL_BUILDER
2072 static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
2073 "Specialization of instance_traits not found. Please check that a "
2074 "specialization exists in file "
2075 "ck_tile/builder/reflect/"
2076 "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
2077 "for the given template parameters.");
2078 return ck_tile::reflect::instance_string<DeviceOp>();
2084 auto arg =
dynamic_cast<const Argument*
>(p_arg);
2090 throw std::runtime_error(
2091 "The argument pointer is not an object of "
2092 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2099 auto p_arg_ =
dynamic_cast<Argument*
>(p_arg);
2102 p_arg_->p_workspace_ = p_workspace;
2105 throw std::runtime_error(
2106 "The argument pointer is not an object of "
2107 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition device_grouped_conv_utils.hpp:119
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition device_grouped_conv_utils.hpp:135
GemmSpecialization
Definition gemm_specialization.hpp:11
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition device_grouped_conv_utils.hpp:56
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition device_grouped_conv_utils.hpp:96
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter3x3
Definition convolution_forward_specialization.hpp:20
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
const char * get_type_name()
Definition data_type.hpp:468
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:674
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1035
Argument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:731
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1036
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1056
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1040
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1070
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1044
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1060
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1081
ComputePtrOffsetOfStridedBatch< NumATensor, I1, NumDTensor > compute_ptr_offset_of_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1078
void Print() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1007
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1038
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1068
BGridPointer p_bs_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1028
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1034
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1083
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:987
NHWGCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1071
AGridDesc_M_K a_grid_desc_m_k_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1053
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1001
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1033
GKYXCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1073
ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor > compute_ptr_offset_of_groups_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1077
index_t num_group_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1047
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1059
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1082
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1041
AGridPointer p_as_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1027
NGCHWTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1070
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:972
void InitGridDesc()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:676
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1029
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1063
Block2ETileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1066
GKCYXTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1072
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1071
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1043
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1055
BGridDesc_N_K b_grid_desc_n_k_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1054
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1042
index_t conv_N_per_block_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1051
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1039
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1049
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1062
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1068
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1037
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1030
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1067
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:957
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1088
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1089
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1292
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1383
float RunGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1092
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1408
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:325
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1415
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I0 static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:350
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NXdlPerWave64 static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:328
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NeedTransposeKernel static constexpr bool NeedTransposeKernel
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:363
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:595
std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1963
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:476
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap( EGridDesc_M_N{}))> Block2ETileMap
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:592
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I4 static constexpr auto I4
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:354
std::conditional_t<!isMultiA &&isMultiB, Tuple< ADataType >, ADataType > GemmADataType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:485
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::conv_ngchw_to_nhwgc_transformer static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:383
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:597
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NumDTensor static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:344
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NumATensor static constexpr index_t NumATensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:342
remove_cvref_t< decltype(GetBGridPointer< isMultiA||isMultiB, GridwiseGemm64, BDataType >())> BGridPointer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:574
static auto MakeArgument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1846
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType, NumGroupsToMerge, index_t, CTranspose > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:371
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:636
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:435
remove_cvref_t< decltype(GetAGridPointer< isMultiA||isMultiB, GridwiseGemm64, ADataType >())> AGridPointer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:572
std::conditional_t<!isMultiB &&isMultiA, Tuple< BDataType >, BDataType > GemmBDataType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:486
GridwiseGemmMultipleABD_xdl_cshuffle< GridwiseGemmMultiABDTemplateParameters > GridwiseGemmMultipleABDBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:539
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::DoElementwiseBeforeCShuffle static constexpr bool DoElementwiseBeforeCShuffle
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:346
std::conditional_t< CTranspose, GridwiseGemmMultipleDCTransposeBase< math::max(NXdlPerWave64, 1)>, GridwiseGemm64 > GridwiseGemmCTranspose64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:556
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:415
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2082
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:461
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:474
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::matrix_padder static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:391
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:587
std::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABDBase< math::max(NXdlPerWave64, 1)>, GridwiseGemmMultipleDBase< math::max(NXdlPerWave64, 1)> > GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:548
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2036
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmTemplateParameters > GridwiseGemmMultipleDBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:542
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:326
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1918
std::conditional_t< isMultiA, std::array< const void *, NumATensor > &, const void * > APointers
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:566
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I5 static constexpr auto I5
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:355
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::ClusterLengthNPerBlock static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:380
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::CTranspose static constexpr bool CTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:367
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::dummy_conv_to_gemm_transformer static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:473
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::ElementwiseBlocksize static constexpr index_t ElementwiseBlocksize
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:611
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:581
std::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABDBase< NXdlPerWave32 >, GridwiseGemmMultipleDBase< NXdlPerWave32 > > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:552
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::isMultiA static constexpr bool isMultiA
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:333
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NXdlPerWave32 static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:329
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:478
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:480
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:618
static auto MakeArgument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1803
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:654
std::conditional_t< CTranspose, GridwiseGemmMultipleDCTransposeBase< NXdlPerWave32 >, GridwiseGemm32 > GridwiseGemmCTranspose32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:560
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:584
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I3 static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:353
std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1920
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I1 static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:351
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:604
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:607
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:578
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::isMultiB static constexpr bool isMultiB
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:334
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2041
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1798
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:395
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::I2 static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:352
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2095
std::conditional_t< isMultiB, std::array< const void *, NumBTensor > &, const void * > BPointers
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:568
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::NumBTensor static constexpr index_t NumBTensor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:343
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::isMultiAB static constexpr bool isMultiAB
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:335
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >::isATensorColMajor static constexpr bool isATensorColMajor
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:357
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:600
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmCTransposeTemplateParameters > GridwiseGemmMultipleDCTransposeBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:545
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129