BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_pipeline_qr_ks_vs_fp8.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
| using | VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &, FmhaMask mask, PositionEncoding, float scale_s, float descale_qk, float descale_sv, void *smem_ptr, BlockDropout &) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr bool | kQLoadOnce = true |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr bool | kHasDropout = Problem::kHasDropout |
| static constexpr index_t | kAlignmentQ |
| static constexpr index_t | kAlignmentK |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentO |
| static constexpr index_t | kAlignmentBias |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "qr_fp8" |
Member Typedef Documentation
◆ BiasDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BlockFmhaShape
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ FmhaMask
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
◆ QDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ RandValOutputDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ SaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VLayout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
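| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::GetSmemSize () |
inline static constexpr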
◆ operator()()
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding>
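| CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &, FmhaMask mask, PositionEncoding, float scale_s, float descale_qk, float descale_sv, void *smem_ptr, BlockDropout &) const |
inline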
NOTICE: the bias may be a materialized mask that contains -inf values; this case needs careful handling.
Member Data Documentation
◆ BiasEnum
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
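| static constexpr auto ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::BiasEnum = Problem::BiasEnum |
static constexpr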
◆ kAlignmentBias
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentBias |
static constexpr
Initial value:
=
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
◆ kAlignmentK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentK |
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
◆ kAlignmentO
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentO |
static constexpr
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24
◆ kAlignmentQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentQ |
static constexpr
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>()
◆ kAlignmentV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kAlignmentV |
static constexpr
Initial value:
= []() {
    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    else
        return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_appendkv_pipeline.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_appendkv_pipeline.hpp:15
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kBlockPerCu |
static constexpr
Initial value:
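= []() {
    if constexpr(Problem::kBlockPerCu != -1)
        return Problem::kBlockPerCu;
    else
    {
        if constexpr(kQKHeaddim <= 32)
        {
            return 2;
        }
        else if constexpr(kQKHeaddim <= 64)
        {
            return 3;
        }
        else if constexpr(kQKHeaddim <= 128)
        {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return 1;
            else
                return 2;
        }
        else if constexpr(kQKHeaddim <= 256)
        {
            return 1;
        }
    }
}()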
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:51
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:44
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
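| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize |
static constexpr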
◆ kHasDropout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
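| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kHasDropout = Problem::kHasDropout |
static constexpr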
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
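| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode |
static constexpr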
◆ kK0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
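| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kK0 = BlockFmhaShape::kK0 |
static constexpr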
◆ kK1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
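| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kK1 = BlockFmhaShape::kK1 |
static constexpr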
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
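| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kM0 = BlockFmhaShape::kM0 |
static constexpr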
◆ kN0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
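| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kN0 = BlockFmhaShape::kN0 |
static constexpr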
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
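| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kN1 = BlockFmhaShape::kN1 |
static constexpr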
◆ kPadHeadDimQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
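| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadHeadDimQ = Problem::kPadHeadDimQ |
static constexpr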
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
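| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadHeadDimV = Problem::kPadHeadDimV |
static constexpr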
◆ kPadSeqLenK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
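| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadSeqLenK = Problem::kPadSeqLenK |
static constexpr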
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
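| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kPadSeqLenQ = Problem::kPadSeqLenQ |
static constexpr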
◆ kQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
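| static constexpr index_t ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kQKHeaddim = BlockFmhaShape::kQKHeaddim |
static constexpr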
◆ kQLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kQLoadOnce = true |
static constexpr
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
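| static constexpr bool ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE |
static constexpr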
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
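| static constexpr const char* ck_tile::BlockFmhaPipelineQRKSVSFp8< Problem_, Policy_ >::name = "qr_fp8" |
static constexpr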
The documentation for this struct was generated from the following file:
- block_fmha_pipeline_qr_ks_vs_fp8.hpp