BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor > Struct Template Reference
Kernel arguments for batched tensor contraction operations.
#include <batched_contraction_kernel.hpp>
Public Attributes
| Type | Name | Description |
| --- | --- | --- |
| const void * | a_ptr | Pointer to input tensor A. |
| const void * | b_ptr | Pointer to input tensor B. |
| std::array< const void *, NumDTensor > | ds_ptr | Array of pointers to auxiliary input tensors D. |
| void * | e_ptr | Pointer to output tensor E. |
| ck_tile::index_t | k_batch | Number of k-splits for split-K batching. |
| ck_tile::index_t | M_dims[NumDimM] | M dimension sizes: [M0, M1, M2, ..., M_{NumDimM-1}]. |
| ck_tile::index_t | N_dims[NumDimN] | N dimension sizes: [N0, N1, N2, ..., N_{NumDimN-1}]. |
| ck_tile::index_t | K_dims[NumDimK] | K dimension sizes: [K0, K1, K2, ..., K_{NumDimK-1}]. |
| ck_tile::index_t | G_dims[NumDimG] | G (batch) dimension sizes: [G0, G1, G2, ..., G_{NumDimG-1}]. |
| ck_tile::index_t | batch_stride_A | Batch stride for tensor A. |
| ck_tile::index_t | batch_stride_B | Batch stride for tensor B. |
| ck_tile::index_t | batch_stride_E | Batch stride for tensor E. |
| std::array< ck_tile::index_t, NumDTensor > | batch_stride_Ds | Batch strides for D tensors. |
| ck_tile::index_t | G_total | Total batch size: G0 * G1 * ... * G_{NumDimG-1}. |
| ck_tile::index_t | M_total | Total M dimension: M0 * M1 * ... * M_{NumDimM-1}. |
| ck_tile::index_t | N_total | Total N dimension: N0 * N1 * ... * N_{NumDimN-1}. |
| ck_tile::index_t | K_total | Total K dimension: K0 * K1 * ... * K_{NumDimK-1}. |
| ck_tile::index_t | stride_A | Leading dimension stride for tensor A (row-major: K_total). |
| ck_tile::index_t | stride_B | Leading dimension stride for tensor B (row-major: K_total). |
| std::array< ck_tile::index_t, NumDTensor > | stride_Ds | Leading dimension strides for D tensors (row-major: N_total). |
| ck_tile::index_t | stride_E | Leading dimension stride for tensor E (row-major: N_total). |
Detailed Description
struct BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >
Kernel arguments for batched tensor contraction operations.
Template Parameters
- NumDimG: Number of batch dimensions.
- NumDimM: Number of M (output row) dimensions.
- NumDimN: Number of N (output column) dimensions.
- NumDimK: Number of K (contraction) dimensions.
- NumDTensor: Number of auxiliary input D tensors. Default is 0.
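The relationship between the per-dimension arrays and the derived totals and strides can be sketched as follows. This is an illustration only: `index_t`, `ContractionDimsSketch`, `product`, and `fill_derived` are hypothetical names introduced here, not part of the library, and the struct mirrors only a subset of the documented members. The stride values follow the row-major conventions stated in the member descriptions (K_total for A and B, N_total for E).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for ck_tile::index_t, used only in this sketch.
using index_t = std::int32_t;

// Hypothetical, reduced mirror of the dimension-related members.
// The real struct also carries pointers, batch strides, and D-tensor data.
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, index_t NumDimK>
struct ContractionDimsSketch
{
    index_t G_dims[NumDimG];
    index_t M_dims[NumDimM];
    index_t N_dims[NumDimN];
    index_t K_dims[NumDimK];
    index_t G_total, M_total, N_total, K_total;
    index_t stride_A, stride_B, stride_E;
};

// Product of all entries of a fixed-size dimension array.
template <std::size_t N>
index_t product(const index_t (&dims)[N])
{
    index_t total = 1;
    for(std::size_t i = 0; i < N; ++i)
    {
        assert(dims[i] > 0); // this sketch expects positive extents
        total *= dims[i];
    }
    return total;
}

// Fill the derived members from the per-dimension sizes.
template <index_t G, index_t M, index_t N, index_t K>
void fill_derived(ContractionDimsSketch<G, M, N, K>& args)
{
    args.G_total = product(args.G_dims);
    args.M_total = product(args.M_dims);
    args.N_total = product(args.N_dims);
    args.K_total = product(args.K_dims);

    // Row-major leading dimensions, as stated in the member descriptions.
    args.stride_A = args.K_total;
    args.stride_B = args.K_total;
    args.stride_E = args.N_total;
}
```

For example, with one batch dimension of size 4, M dimensions {8, 16}, N dimensions {2, 32}, and a single K dimension of size 64, the sketch yields G_total = 4, M_total = 128, N_total = 64, K_total = 64, and all three strides equal to 64.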
Member Data Documentation
◆ a_ptr
| const void* BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::a_ptr |
Pointer to input tensor A.
◆ b_ptr
| const void* BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::b_ptr |
Pointer to input tensor B.
◆ batch_stride_A
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::batch_stride_A |
Batch stride for tensor A.
◆ batch_stride_B
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::batch_stride_B |
Batch stride for tensor B.
◆ batch_stride_Ds
| std::array<ck_tile::index_t, NumDTensor> BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::batch_stride_Ds |
Batch strides for D tensors.
◆ batch_stride_E
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::batch_stride_E |
Batch stride for tensor E.
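A batch stride is typically used to locate the start of one batch's data from the tensor's base pointer. The sketch below shows that arithmetic under two assumptions that this page does not state: that batch strides are counted in elements rather than bytes, and that the flat batch index runs from 0 to G_total - 1. `index_t` and `batch_base` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for ck_tile::index_t, used only in this sketch.
using index_t = std::int32_t;

// Return the start of batch `g` inside a tensor whose batches are
// `batch_stride` elements apart (assumption: strides are in elements).
template <typename T>
const T* batch_base(const void* ptr, index_t batch_stride, index_t g, index_t G_total)
{
    assert(g >= 0 && g < G_total);
    // Widen before multiplying so large tensors do not overflow 32 bits.
    return static_cast<const T*>(ptr) + static_cast<std::int64_t>(g) * batch_stride;
}
```

With a densely packed A tensor, one would expect batch_stride_A to equal M_total * K_total, but that is also an assumption; the struct stores the stride explicitly so other layouts remain possible.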
◆ ds_ptr
| std::array<const void*, NumDTensor> BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::ds_ptr |
Array of pointers to auxiliary input tensors D.
◆ e_ptr
| void* BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::e_ptr |
Pointer to output tensor E.
◆ G_dims
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::G_dims[NumDimG] |
G (batch) dimension sizes: [G0, G1, G2, ..., G_{NumDimG-1}].
◆ G_total
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::G_total |
Total batch size: G0 * G1 * ... * G_{NumDimG-1}.
◆ k_batch
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::k_batch |
Number of k-splits for split-K batching.
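To make the idea of k-splits concrete, here is one common way to divide a contraction range of length K_total into k_batch pieces, using ceiling division so that every element is covered. This is a generic sketch and not necessarily the partitioning the kernel itself uses; `index_t`, `KRange`, and `split_k_range` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for ck_tile::index_t, used only in this sketch.
using index_t = std::int32_t;

// Half-open range [begin, end) of K indices handled by one split.
struct KRange
{
    index_t begin;
    index_t end;
};

// Ceil-divide K_total across k_batch splits. Trailing splits may be
// shorter than the others, or empty when k_batch exceeds K_total.
inline KRange split_k_range(index_t K_total, index_t k_batch, index_t split_id)
{
    assert(k_batch >= 1);
    assert(split_id >= 0 && split_id < k_batch);
    const index_t chunk = (K_total + k_batch - 1) / k_batch;
    const index_t begin = std::min(split_id * chunk, K_total);
    const index_t end   = std::min(begin + chunk, K_total);
    return {begin, end};
}
```

With K_total = 100 and k_batch = 3, the three splits cover [0, 34), [34, 68), and [68, 100). With k_batch = 1 the single split covers the whole range, which corresponds to no splitting at all. Partial results from separate splits must be combined into E afterwards; how that reduction is done is outside the scope of this sketch.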
◆ K_dims
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::K_dims[NumDimK] |
K dimension sizes: [K0, K1, K2, ..., K_{NumDimK-1}].
◆ K_total
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::K_total |
Total K dimension: K0 * K1 * ... * K_{NumDimK-1}.
◆ M_dims
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::M_dims[NumDimM] |
M dimension sizes: [M0, M1, M2, ..., M_{NumDimM-1}].
◆ M_total
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::M_total |
Total M dimension: M0 * M1 * ... * M_{NumDimM-1}.
◆ N_dims
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::N_dims[NumDimN] |
N dimension sizes: [N0, N1, N2, ..., N_{NumDimN-1}].
◆ N_total
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::N_total |
Total N dimension: N0 * N1 * ... * N_{NumDimN-1}.
◆ stride_A
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::stride_A |
Leading dimension stride for tensor A (row-major: K_total).
◆ stride_B
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::stride_B |
Leading dimension stride for tensor B (row-major: K_total).
◆ stride_Ds
| std::array<ck_tile::index_t, NumDTensor> BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::stride_Ds |
Leading dimension strides for D tensors (row-major: N_total).
◆ stride_E
| ck_tile::index_t BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >::stride_E |
Leading dimension stride for tensor E (row-major: N_total).
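The leading-dimension strides treat each tensor as a 2D matrix whose rows and columns are the flattened M, N, or K index groups. A minimal sketch of that addressing follows, assuming row-major flattening in which the last dimension of each group varies fastest and strides are counted in elements. `index_t`, `flatten`, and `element_offset` are hypothetical helpers, not library functions.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for ck_tile::index_t, used only in this sketch.
using index_t = std::int32_t;

// Collapse a multi-index (for example m0, m1, ...) into one flat index.
// Assumption: row-major order, so the last dimension varies fastest.
template <std::size_t N>
index_t flatten(const std::array<index_t, N>& idx, const std::array<index_t, N>& dims)
{
    index_t flat = 0;
    for(std::size_t i = 0; i < N; ++i)
    {
        assert(idx[i] >= 0 && idx[i] < dims[i]);
        flat = flat * dims[i] + idx[i];
    }
    return flat;
}

// Element offset inside one batch of a row-major matrix whose rows are
// `stride` elements apart (for example stride_E == N_total for tensor E).
inline std::int64_t element_offset(index_t row, index_t col, index_t stride)
{
    return static_cast<std::int64_t>(row) * stride + col;
}
```

For instance, the multi-index (1, 2) in dimensions (3, 4) flattens to 1 * 4 + 2 = 6, and row 6, column 5 of a matrix with stride 64 sits at element offset 6 * 64 + 5 = 389.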
The documentation for this struct was generated from the following file: