BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference

BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference
ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference

#include <block_fmha_pipeline_qr_ks_vs_async.hpp>

Public Types

using Problem = remove_cvref_t<Problem_>
using Policy = remove_cvref_t<Policy_>
using QDataType = remove_cvref_t<typename Problem::QDataType>
using KDataType = remove_cvref_t<typename Problem::KDataType>
using VDataType = remove_cvref_t<typename Problem::VDataType>
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>
using PDataType = remove_cvref_t<typename Problem::PDataType>
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>
using ODataType = remove_cvref_t<typename Problem::ODataType>
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>

Public Member Functions

template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
CK_TILE_HOST_DEVICE auto operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
CK_TILE_HOST_DEVICE auto operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const

Static Public Member Functions

static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr bool kQLoadOnce = true
static constexpr index_t kBlockSize = Problem::kBlockSize
static constexpr index_t kM0 = BlockFmhaShape::kM0
static constexpr index_t kN0 = BlockFmhaShape::kN0
static constexpr index_t kK0 = BlockFmhaShape::kK0
static constexpr index_t kN1 = BlockFmhaShape::kN1
static constexpr index_t kK1 = BlockFmhaShape::kK1
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim
static constexpr bool kIsGroupMode = Problem::kIsGroupMode
static constexpr bool kPadSeqLenQ = true
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK
static constexpr bool kPadHeadDimQ = true
static constexpr bool kPadHeadDimV = true
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap
static constexpr auto BiasEnum = Problem::BiasEnum
static constexpr bool kStoreLSE = Problem::kStoreLSE
static constexpr bool kHasDropout = Problem::kHasDropout
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>()
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>()
static constexpr index_t kAlignmentV
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>()
static constexpr index_t kAlignmentBias
static constexpr index_t kBlockPerCu
static constexpr const char * name = "qr_async"

Member Typedef Documentation

◆ AttentionVariant

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>

◆ BiasDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType>

◆ BlockFmhaShape

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>

◆ DropoutType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>

◆ FmhaMask

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask>

◆ KDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType>

◆ LSEDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType>

◆ OaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType>

◆ ODataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType>

◆ PDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType>

◆ Policy

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_>

◆ Problem

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

◆ QDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType>

◆ RandValOutputDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>

◆ SaccDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType>

◆ SMPLComputeDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>

◆ VDataType

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType>

◆ VLayout

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>

Member Function Documentation

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ operator()() [1/2]

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::operator() ( const QDramBlockWindowTmp & q_dram_block_window_tmp,
const KDramBlockWindowTmp & k_dram_block_window_tmp,
const VDramBlockWindowTmp & v_dram_block_window_tmp,
const BiasDramBlockWindowTmp & bias_dram_block_window_tmp,
RandValDramBlockWindowTmp & randval_dram_block_window_tmp,
LSEDramBlockWindowTmp & lse_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant & variant,
const AttentionVariantParams & variant_params,
const BlockIndices & block_indices,
void * smem_ptr,
DropoutType & dropout ) const
inline

◆ operator()() [2/2]

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::operator() ( const QDramBlockWindowTmp & q_dram_block_window_tmp,
const QElementFunction & q_element_func,
const KDramBlockWindowTmp & k_dram_block_window_tmp,
const KElementFunction & ,
const VDramBlockWindowTmp & v_dram_block_window_tmp,
const VElementFunction & v_element_func,
const BiasDramBlockWindowTmp & bias_dram_block_window_tmp,
const BiasElementFunction & bias_element_func,
RandValDramBlockWindowTmp & randval_dram_block_window_tmp,
LSEDramBlockWindowTmp & lse_dram_window_tmp,
const LSEElementFunction & lse_element_func,
const SAccElementFunction & s_acc_element_func,
const PComputeElementFunction & p_compute_element_func,
const OAccElementFunction & o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant & variant,
const AttentionVariantParams & variant_params,
const BlockIndices & block_indices,
void * smem_ptr,
DropoutType & dropout ) const
inline

NOTICE: the bias might be a materialized mask that includes -inf values, which needs consideration. ALiBi does not have this problem.

Member Data Documentation

◆ BiasEnum

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
auto ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::BiasEnum = Problem::BiasEnum
static constexpr

◆ kAlignmentBias

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kAlignmentBias
static constexpr
Initial value:
= kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
Referenced symbol: static constexpr bool kPadSeqLenK
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64

◆ kAlignmentK

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kAlignmentK = Policy::template GetAlignmentK<Problem>()
static constexpr

◆ kAlignmentO

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kAlignmentO = Policy::template GetAlignmentO<Problem>()
static constexpr

◆ kAlignmentQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kAlignmentQ = Policy::template GetAlignmentQ<Problem>()
static constexpr

◆ kAlignmentV

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kAlignmentV
static constexpr
Initial value:
= []() {
    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        return Policy::template GetAlignmentV<Problem>();
    else
        return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()

◆ kBlockPerCu

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize
static constexpr

◆ kHasDropout

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kHasDropout = Problem::kHasDropout
static constexpr

◆ kHasLogitsSoftCap

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kHasLogitsSoftCap = Problem::kHasLogitsSoftCap
static constexpr

◆ kIsGroupMode

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode
static constexpr

◆ kK0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kK0 = BlockFmhaShape::kK0
static constexpr

◆ kK1

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kK1 = BlockFmhaShape::kK1
static constexpr

◆ kM0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kM0 = BlockFmhaShape::kM0
static constexpr

◆ kN0

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kN0 = BlockFmhaShape::kN0
static constexpr

◆ kN1

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kN1 = BlockFmhaShape::kN1
static constexpr

◆ kPadHeadDimQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kPadHeadDimQ = true
static constexpr

◆ kPadHeadDimV

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kPadHeadDimV = true
static constexpr

◆ kPadSeqLenK

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kPadSeqLenK = Problem::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kPadSeqLenQ = true
static constexpr

◆ kQKHeaddim

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kQKHeaddim = BlockFmhaShape::kQKHeaddim
static constexpr

◆ kQLoadOnce

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kQLoadOnce = true
static constexpr

◆ kStoreLSE

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
bool ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE
static constexpr

◆ kSubQKHeaddim

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
index_t ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim
static constexpr

◆ name

template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
const char* ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::name = "qr_async"
static constexpr

The documentation for this struct was generated from the following file: