block_fmha_pipeline_qs_ks_vs_default_policy.hpp Source File

block_fmha_pipeline_qs_ks_vs_default_policy.hpp Source File

Composable Kernel: block_fmha_pipeline_qs_ks_vs_default_policy.hpp Source File
block_fmha_pipeline_qs_ks_vs_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"

9namespace ck_tile {
10
11// This pipeline is qkv all located in LDS
13 : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
14 /* AsyncCopy = */ false,
15 /* NumPrefetchK = */ 1,
16 /* NumPrefetchV = */ 1>
17{
18 template <typename Problem>
20 {
21 return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
22 sizeof(typename Problem::KDataType);
23 } // namespace ck_tile
24
25 template <typename Problem>
27 {
28 return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
29 sizeof(typename Problem::VDataType);
30 }
31
32 template <typename Problem>
34 {
35 return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
37 }
38};
39
40} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
BlockFmhaPipelineQSKSVSDefaultPolicy
Definition block_fmha_pipeline_qs_ks_vs_default_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeK()
Definition block_fmha_pipeline_qs_ks_vs_default_policy.hpp:19
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeV()
Definition block_fmha_pipeline_qs_ks_vs_default_policy.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qs_ks_vs_default_policy.hpp:33
BlockFmhaPipelineQXKSVSCustomPolicy
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static CK_TILE_HOST_DEVICE constexpr std::enable_if_t< std::is_convertible_v< decltype(Problem::kHasDropout), bool >, ck_tile::index_t > GetSmemSizeDropout(int)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:687
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:620
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:486