DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > Struct Template Reference#
WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types. More...
#include <device_gemm_xdl_cshuffle_v3_mx.hpp>
Classes | |
| struct | Invoker |
Public Types | |
| template<index_t NXdlPerWave_> | |
| using | GridwiseGemmMXBase |
| template<index_t NXdlPerWave_> | |
| using | GridwiseGemmMXBPreshuffleBase |
| using | GridwiseGemm64 |
| using | GridwiseGemm32 |
| using | Argument = typename GridwiseGemm64::Argument |
Public Member Functions | |
| bool | IsSupportedArgument (const BaseArgument *p_arg) override |
| std::unique_ptr< BaseArgument > | MakeArgumentPointer (const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideScaleA, ck::index_t StrideB, ck::index_t StrideScaleB, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override |
| std::unique_ptr< BaseInvoker > | MakeInvokerPointer () override |
| std::string | GetTypeString () const override |
| Public Member Functions inherited from ck::tensor_operation::device::BaseOperator | |
| BaseOperator ()=default | |
| BaseOperator (const BaseOperator &)=default | |
| BaseOperator & | operator= (const BaseOperator &)=default |
| virtual std::string | GetInstanceString () const |
| virtual std::string | GetTypeIdName () const |
| virtual std::optional< std::string > | GetObjectName () const |
| virtual std::optional< std::string > | GetTemplateInfo () const |
| virtual std::string | GetTypeIdHashCode () const |
| virtual size_t | GetWorkSpaceSize (const BaseArgument *) const |
| virtual void | SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const |
| virtual | ~BaseOperator () |
Static Public Member Functions | |
| static constexpr bool | IsValidCompilationParameter () |
| static bool | IsSupportedArgument (const Argument &arg) |
| static auto | MakeArgument (const ADataType *p_a, const AScaleDataType *p_a_scale, const BDataType *p_b, const BScaleDataType *p_b_scale, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) |
| static auto | MakeInvoker () |
Static Public Attributes | |
| static GET_NXDL_PER_WAVE_IMPL constexpr auto | NXdlPerWave64 = GetNXdlPerWave<true>() |
| static constexpr auto | NXdlPerWave32 = GetNXdlPerWave<false>() |
Detailed Description
struct ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >
WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types.
This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for microscale-compliant data types.
Assumptions:
- A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
- Each scale applies to ScaleBlockSize elements in K direction
- A scale matrix is a row-major
- B scale matrix is a column-major
- Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the exponent will be interpreted as conventional biased Float32 exponent (E8M0)
Tunable parameters. The CK instance includes a series of tunable template parameters to control the parallel granularity of the workload to achieve load balancing on different hardware platforms. These parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc.
- Block Size determines the number of threads in the thread block.
- M/N/K Per Block determines the size of tile that each thread block is responsible for calculating.
- M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) instructions operating on a per-wavefront basis.
- A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will result in compilation errors.
Conditions for achieving computational load balancing on different hardware platforms can vary.
- Template Parameters
-
KPerBlock is the number of elements in K dimension that each block processes (multiply with packed_size_v to get the actual KPerBlock)
Serialized version of the algorithm:
Member Typedef Documentation
◆ Argument
| using ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::Argument = typename GridwiseGemm64::Argument |
◆ GridwiseGemm32
| using ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::GridwiseGemm32 |
◆ GridwiseGemm64
| using ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::GridwiseGemm64 |
◆ GridwiseGemmMXBase
| using ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::GridwiseGemmMXBase |
◆ GridwiseGemmMXBPreshuffleBase
| using ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::GridwiseGemmMXBPreshuffleBase |
Member Function Documentation
◆ GetTypeString()
|
inlineoverridevirtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsSupportedArgument() [1/2]
|
inlinestatic |
◆ IsSupportedArgument() [2/2]
|
inlineoverridevirtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsValidCompilationParameter()
|
inlinestaticconstexpr |
◆ MakeArgument()
|
inlinestatic |
◆ MakeArgumentPointer()
|
inlineoverridevirtual |
◆ MakeInvoker()
|
inlinestatic |
◆ MakeInvokerPointer()
|
inlineoverridevirtual |
Member Data Documentation
◆ NXdlPerWave32
|
staticconstexpr |
◆ NXdlPerWave64
|
staticconstexpr |
The documentation for this struct was generated from the following file: