32template <
typename ALayout,
39 typename CShuffleDataType,
42 typename AElementwiseOperation,
43 typename BElementwiseOperation,
44 typename CDEElementwiseOperation,
57 typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
58 typename ABlockTransferThreadClusterArrangeOrder,
59 typename ABlockTransferSrcAccessOrder,
60 index_t ABlockTransferSrcVectorDim,
61 index_t ABlockTransferSrcScalarPerVector,
62 index_t ABlockTransferDstScalarPerVector_AK1,
64 typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
65 typename BBlockTransferThreadClusterArrangeOrder,
66 typename BBlockTransferSrcAccessOrder,
67 index_t BBlockTransferSrcVectorDim,
68 index_t BBlockTransferSrcScalarPerVector,
69 index_t BBlockTransferDstScalarPerVector_BK1,
71 index_t CShuffleMXdlPerWavePerShuffle,
72 index_t CShuffleNXdlPerWavePerShuffle,
73 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
74 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
77 typename ComputeDataType = EDataType,
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CDEElementwiseOperation>
111 template <index_t NXdlPerWave_>
121 AElementwiseOperation,
122 BElementwiseOperation,
125 NumGemmKPrefetchStage,
134 ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
135 ABlockTransferThreadClusterArrangeOrder,
136 ABlockTransferSrcAccessOrder,
137 ABlockTransferSrcVectorDim,
138 ABlockTransferSrcScalarPerVector,
139 ABlockTransferDstScalarPerVector_AK1,
142 BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
143 BBlockTransferThreadClusterArrangeOrder,
144 BBlockTransferSrcAccessOrder,
145 BBlockTransferSrcVectorDim,
146 BBlockTransferSrcScalarPerVector,
147 BBlockTransferDstScalarPerVector_BK1,
150 CShuffleMXdlPerWavePerShuffle,
151 CShuffleNXdlPerWavePerShuffle,
152 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
153 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
159 template <
typename ELay>
162 const auto c_grid_desc_m_n = [&]() {
175 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
176 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
196 const std::array<index_t, NumDTensor>& NRaws,
197 const std::array<index_t, NumDTensor>& DsStride)
214 return static_cast<const DDataType*
>(
nullptr);
222 [&]([[maybe_unused]]
auto i)
constexpr {
238 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
240 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
251 CDEElementwiseOperation,
294 std::vector<const void*>& p_Bs,
295 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
296 std::vector<void*>& p_Es,
297 std::vector<GemmDesc>& gemm_descs,
298 AElementwiseOperation a_element_op,
299 BElementwiseOperation b_element_op,
300 CDEElementwiseOperation cde_element_op)
314 std::vector<const void*>& p_Bs,
315 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
316 std::vector<void*>& p_Es,
317 std::vector<GemmDesc>& gemm_descs,
318 AElementwiseOperation a_element_op,
319 BElementwiseOperation b_element_op,
320 CDEElementwiseOperation cde_element_op,
337 throw std::runtime_error(
"Error! group_count_ != p_As/Bs/Ds/Es size");
347 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
349 const index_t M = gemm_descs[i].M_;
350 const index_t N = gemm_descs[i].N_;
351 const index_t K = gemm_descs[i].K_;
353 if(M == 0 || N == 0 || K == 0)
359 const index_t stride_a = gemm_descs[i].stride_A_;
360 const index_t stride_b = gemm_descs[i].stride_B_;
361 const index_t stride_e = gemm_descs[i].stride_C_;
368 const auto c_grid_desc_m_n =
378 p_ds_grid(j) =
static_cast<const DDataType*
>(p_Ds[i][j]);
382 const auto local_b2c_tile_map =
384 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
392 auto grouped_block_2_ctile_map =
395 std::array<index_t, NumDTensor> stride_ds;
400 throw std::runtime_error(
401 "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
404 stride_ds[j] = gemm_descs[i].stride_Ds_[j];
406 stride_Ds_.emplace_back(std::move(stride_ds));
426 std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
453 const auto c_grid_desc_m_n =
456 const auto local_b2c_tile_map =
458 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
466 auto grouped_block_2_ctile_map =
470 karg.KPadded = k_padded;
471 karg.K0Padded = k0_padded;
480 std::cout <<
"block_start: " << block_start <<
"\n"
481 <<
"block_end: " << block_end <<
"\n"
482 <<
"tiles: " << tiles << std::endl
485 std::cout <<
"KPadded: " << karg.KPadded << std::endl
486 <<
"K0Padded: " << karg.K0Padded << std::endl
487 <<
"KBatch: " << karg.k_batch << std::endl
488 <<
"grid_size_: " << karg.KPadded << std::endl;
497 std::size_t offset = 0;
501 arg.karg_.p_c_grid = p_workspace + offset;
502 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
503 offset += tiles * MPerBlock * NPerBlock;
506 std::cout <<
"block_start: " << arg.block_start_ <<
"\n"
507 <<
"block_end: " << arg.block_end_ <<
"\n"
508 <<
"tiles: " << tiles <<
"\n"
509 <<
"offset: " << offset << std::endl;
516 std::size_t size_bytes{0};
520 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
529 index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
530 return tiles * MPerBlock * NPerBlock;
545 std::vector<std::array<const void*, NumDTensor>>&
p_Ds_;
574 template <
typename Gr
idwiseGemm>
577 void* dev_gemm_workspace,
580 auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
581 CheckArgument<GridwiseGemm>(arg, stream_config);
583 if(dev_gemm_args ==
nullptr)
585 std::ostringstream err;
586 err <<
"The gemm arguments device buffer is not allocated!" <<
" In " << __FILE__
587 <<
":" << __LINE__ <<
", in function: " << __func__;
588 throw std::runtime_error(err.str());
591 if(dev_gemm_workspace ==
nullptr)
593 std::ostringstream err;
594 err <<
"The gemm workspace buffer is not allocated!" <<
" In " << __FILE__ <<
":"
595 << __LINE__ <<
", in function: " << __func__;
596 throw std::runtime_error(err.str());
601 if(all_have_main_k_block_loop)
603 ave_time = DispatchKernel<GridwiseGemm, true>(
604 arg, dev_gemm_args, dev_gemm_workspace, stream_config);
608 ave_time = DispatchKernel<GridwiseGemm, false>(
609 arg, dev_gemm_args, dev_gemm_workspace, stream_config);
629 template <
typename Gr
idwiseGemm>
634 std::ostringstream err;
635 err <<
"The gemm arguments device buffer is not allocated!" <<
" In " << __FILE__
636 <<
":" << __LINE__ <<
", in function: " << __func__;
637 throw std::runtime_error(err.str());
642 std::ostringstream err;
643 err <<
"The gemm workspace buffer is not allocated!" <<
" In " << __FILE__ <<
":"
644 << __LINE__ <<
", in function: " << __func__;
645 throw std::runtime_error(err.str());
656 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
660 template <
typename Gr
idwiseGemm>
661 auto CheckArgument(
const Argument& arg,
const StreamConfig& stream_config)
const
663 bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
666 const auto a_grid_desc_kbatch_ak0_m_ak1 =
667 GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(
668 arg.gemm_kernel_args_[0].karg_.M,
669 arg.gemm_kernel_args_[0].karg_.MPadded,
670 arg.gemm_kernel_args_[0].karg_.K,
671 arg.gemm_kernel_args_[0].karg_.StrideA,
672 arg.gemm_kernel_args_[0].karg_.k_batch,
673 arg.gemm_kernel_args_[0].karg_.K0Padded,
674 arg.gemm_kernel_args_[0].karg_.KPadded);
676 all_have_kbatch_gt_one = arg.K_BATCH > 1;
677 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
678 a_grid_desc_kbatch_ak0_m_ak1.GetLength(
I1) *
679 a_grid_desc_kbatch_ak0_m_ak1.GetLength(
I3));
682 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
684 const auto& gemm_arg =
reinterpret_cast<const typename GridwiseGemm::Argument&
>(
685 arg.gemm_kernel_args_[i].karg_);
691 if(!GridwiseGemm::CheckValidity(gemm_arg))
693 std::ostringstream err;
694 err <<
"Group id: " << i <<
" has invalid GridwiseGemm settings!" << __FILE__
695 <<
":" << __LINE__ <<
", in function: " << __func__;
696 throw std::runtime_error(err.str());
699 const auto a_grid_desc_kbatch_ak0_m_ak1 =
700 GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
708 bool not_all_have_main_k_block_loop_same =
709 all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
710 a_grid_desc_kbatch_ak0_m_ak1.GetLength(
I1) *
711 a_grid_desc_kbatch_ak0_m_ak1.GetLength(
I3));
712 bool not_all_have_kbatch_value_same =
713 all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
715 if(not_all_have_main_k_block_loop_same)
717 std::ostringstream err;
718 err <<
"Not all gemms have same value for main_k0_block_loop! in " << __FILE__
719 <<
":" << __LINE__ <<
", in function: " << __func__;
720 throw std::runtime_error(err.str());
723 if(not_all_have_kbatch_value_same)
725 std::ostringstream err;
726 err <<
"Not all gemms have same kbatch value (=1 or >1)! " <<
"group [" << i
727 <<
"], kbatch: " << gemm_arg.k_batch
728 <<
", group [0], kbatch: " << gemm_arg.k_batch <<
" in " << __FILE__ <<
":"
729 << __LINE__ <<
", in function: " << __func__;
730 throw std::runtime_error(err.str());
733 return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
736 template <
typename Gr
idwiseGemm,
bool HasMainKBlockLoop>
737 float DispatchKernel(
const Argument& arg,
738 void* dev_gemm_kargs,
739 void* dev_gemm_workspace,
740 const StreamConfig& stream_config)
const
742 const auto gemm_kernel =
747 AElementwiseOperation,
748 BElementwiseOperation,
753 ck::Tuple<EGridDesc_M_N>,
755 ck::Tuple<EDataType*>,
757 CDEElementwiseOperation>;
758 return LaunchKernel(gemm_kernel,
766 template <
typename KernelFunction,
typename KernelFunction2>
767 float LaunchKernel(
const KernelFunction& gemm_kernel,
768 const KernelFunction2& elementwise_kernel,
770 void* dev_gemm_kargs,
771 [[maybe_unused]]
void* dev_gemm_workspace,
772 const StreamConfig& stream_config)
const
777 hipMemcpyAsync(dev_gemm_kargs,
778 arg.gemm_kernel_args_.data(),
779 arg.gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg),
780 hipMemcpyHostToDevice,
783 auto preprocess = [&]() {
785 dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.
stream_id_));
793 dim3(arg.grid_size_),
797 arg.gemm_kernel_args_.size(),
803 for(
size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
808 dim3(arg.group_grid_size_[i]),
812 arg.elementwise_d_grid_descs_m_n_[i]),
813 make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
815 arg.ds_grid_pointer_[i]),
817 Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
818 arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
819 arg.cde_element_op_);
842 std::cout <<
"The group count is not equal to sum of skipped groups "
843 "and kernel args size!"
849 bool supported =
true;
854 bool group_arg_valid =
false;
867 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(gemm_arg));
871 if(not group_arg_valid)
875 std::cout <<
"[" << __func__ <<
"] group id: " << i
876 <<
" has invalid GridwiseGemm settings!" << std::endl;
880 supported = supported && group_arg_valid;
891 std::vector<const void*>& p_Bs,
892 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
893 std::vector<void*>& p_Es,
894 std::vector<GemmDesc> gemm_descs,
895 AElementwiseOperation a_elementwise_op,
896 BElementwiseOperation b_elementwise_op,
897 CDEElementwiseOperation cde_elementwise_op)
909 std::unique_ptr<BaseArgument>
911 std::vector<const void*>& p_Bs,
912 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
913 std::vector<void*>& p_Es,
914 std::vector<GemmDesc>& gemm_descs,
915 AElementwiseOperation a_elementwise_op,
916 BElementwiseOperation b_elementwise_op,
917 CDEElementwiseOperation cde_elementwise_op)
override
919 return std::make_unique<Argument>(p_As,
933 return std::make_unique<Invoker>(
Invoker{});
938 auto str = std::stringstream();
941 str <<
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage"
943 << std::string(ALayout::name)[0] <<
","
944 << std::string(BLayout::name)[0] <<
","
945 << std::string(ELayout::name)[0] <<
","
954 << MXdlPerWave <<
", "
955 << NXdlPerWave <<
", "
956 << ABlockTransferSrcScalarPerVector <<
", "
957 << BBlockTransferSrcScalarPerVector <<
", "
958 << CShuffleMXdlPerWavePerShuffle <<
", "
959 << CShuffleNXdlPerWavePerShuffle <<
", "
969 auto arg_ptr =
dynamic_cast<Argument*
>(p_arg);
972 arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
975 throw std::runtime_error(
976 "The argument pointer is not an object of "
977 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
982 auto arg =
dynamic_cast<const Argument*
>(p_arg);
988 throw std::runtime_error(
989 "The argument pointer is not an object of "
990 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
995 auto arg =
dynamic_cast<const Argument*
>(p_arg);
998 return arg->GetWorkspaceSizeBytes();
1001 throw std::runtime_error(
1002 "The argument pointer is not an object of "
1003 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
1011 auto p_arg_ =
dynamic_cast<Argument*
>(p_arg);
1014 p_arg_->p_workspace_ = p_workspace;
1015 p_arg_->UpdateEPointers();
1018 throw std::runtime_error(
1019 "The argument pointer is not an object of "
1020 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
1030 auto p_arg_ =
dynamic_cast<Argument*
>(p_arg);
1033 p_arg_->UpdateKBatch(kbatch);
1036 throw std::runtime_error(
1037 "The argument pointer is not an object of "
1038 "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:38
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
hipStream_t stream_id_
Definition ck/stream_config.hpp:11
int log_level_
Definition ck/stream_config.hpp:13
Definition block_to_ctile_map.hpp:541
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:440
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CalculateMPadded __host__ static __device__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:196
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:213
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CalculateK0Padded __host__ static __device__ auto CalculateK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:206
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CGridDesc_M_N remove_cvref_t< decltype(MakeCGridDescriptor_M_N(1, 1, 1))> CGridDesc_M_N
Definition gridwise_gemm_xdlops_v2r4r2.hpp:661
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:371
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType >::CalculateNPadded __host__ static __device__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:201
Definition block_to_ctile_map.hpp:872
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:269
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:271
index_t block_start_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:272
index_t block_end_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:272
GemmTransKernelArg()=default
GemmKernelArgument karg_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:270
GemmTransKernelArg(GemmKernelArgument &&karg, GroupedGemmBlock2ETileMap &&b2c_map, index_t block_start, index_t block_end)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:275
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:291
std::vector< std::array< const void *, NumDTensor > > & p_Ds_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:545
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:542
CDEElementwiseOperation cde_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:543
std::size_t GetWorkspaceSize(std::size_t group) const
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:526
index_t grid_size_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:537
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t kbatch)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:313
index_t skipped_group_count_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:536
void UpdateEPointers()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:493
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:293
void UpdateKBatch(index_t kbatch)
Set new kbatch value.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:441
std::vector< std::array< index_t, NumDTensor > > stride_Ds_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:546
index_t group_count_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:535
void * p_dev_gemm_kargs_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:539
std::vector< GemmTransKernelArg > gemm_kernel_args_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:547
std::vector< DsGridPointer > ds_grid_pointer_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:552
std::vector< CGridDesc_M_N > elementwise_c_grid_descs_m_n_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:550
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:541
std::vector< void * > e_ptrs_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:553
index_t K_BATCH
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:534
std::vector< DsGridDesc_M_N > elementwise_d_grid_descs_m_n_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:551
std::vector< index_t > group_grid_size_
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:548
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:514
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:558
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:653
float Run(const Argument &arg, void *dev_gemm_args, void *dev_gemm_workspace, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:575
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:630
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:92
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:96
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:95
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, WorkspaceDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType > GridwiseGemmBase
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:112
static constexpr auto MakeDsGridPointer()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:208
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:98
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1028
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:195
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:237
static constexpr auto I3
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:103
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:967
decltype(concat_tuple(ck::Tuple< WorkspaceDataType * >{}, DsGridPointer{})) CDDataTypes
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:233
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:885
static constexpr auto MakeElementwiseInputSequence()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:219
static constexpr auto I2
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:102
typename GridwiseGemm64::Argument GemmKernelArgument
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:266
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:890
OffsettedBlockToCTileMap< Block2ETileMapKSplit > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:265
decltype(MakeElementwiseInputSequence()) ElementwiseInputSequence
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:235
static constexpr index_t K0PerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:105
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:825
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:157
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:931
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:107
typename GridwiseGemm64::CGridDesc_M_N EGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:229
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:230
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:980
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:239
static constexpr index_t DefaultKBatch
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:287
float WorkspaceDataType
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:108
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMap
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:244
static void SetKBatchSize(Argument &arg, index_t kbatch)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1023
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:910
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:160
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:158
static constexpr auto I0
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:100
std::string GetTypeString() const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:936
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage DeviceOp
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:93
decltype(concat_tuple(ck::Tuple< CGridDesc_M_N >{}, DsGridDesc_M_N{})) CDGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:232
static auto MakeInvoker()
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:929
typename GridwiseGemm64::CGridDesc_M_N CGridDesc_M_N
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:228
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &stream_config=StreamConfig{}) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:1006
static constexpr auto I1
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:101
GridwiseElementwise< CDGridDesc_M_N, ck::Tuple< EGridDesc_M_N >, CDDataTypes, ck::Tuple< EDataType * >, Block2TileMap, CDEElementwiseOperation, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 0, 1 >, ElementwiseInputSequence, ck::Sequence< CDEShuffleBlockTransferScalarPerVector_NPerBlock >, I1, I1 > GridwiseElementwise
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:245
BlockToCTileMap_KSplit_M00_N0_M01Adapt< MPerBlock, NPerBlock, CGridDesc_M_N > Block2ETileMapKSplit
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:242
static constexpr index_t B2E_M01
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:264
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:993
decltype(MakeDsGridPointer()) DsGridPointer
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:231
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp:831
Definition device_grouped_gemm_splitk.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129