block_fmha_bwd_dot_do_o.hpp Source File

block_fmha_bwd_dot_do_o.hpp Source File

Composable Kernel: block_fmha_bwd_dot_do_o.hpp Source File
block_fmha_bwd_dot_do_o.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
13{
17
18 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
19 static constexpr index_t kBlockSize = Problem::kBlockSize;
20 static constexpr index_t kVHeaddim = Problem::kVHeaddim;
21
22 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
23 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
24 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
25
26 static constexpr index_t kAlignmentO =
27 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
28 static constexpr index_t kAlignmentOGrad =
29 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
30
31 CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
32
33 template <typename ODramBlockWindowTmp,
34 typename OGradDramBlockWindowTmp,
35 typename DDramBlockWindowTmp>
36 CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
37 const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
38 DDramBlockWindowTmp& d_dram_block_window_tmp,
39 float p_undrop) const
40 {
41 static_assert(
42 std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
43 std::is_same_v<OGradDataType,
45 std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
46 "wrong!");
47
48 static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
49 kBlockSize ==
50 OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
51 kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
52 "wrong!");
53
54 auto o_dram_window =
55 make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
56 o_dram_block_window_tmp.get_window_lengths(),
57 o_dram_block_window_tmp.get_window_origin(),
58 Policy::template MakePreODramTileDistribution<Problem>());
59
60 auto o = load_tile(o_dram_window);
61
62 auto do_dram_window =
63 make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
64 do_dram_block_window_tmp.get_window_lengths(),
65 do_dram_block_window_tmp.get_window_origin(),
66 Policy::template MakePreOGradDramTileDistribution<Problem>());
67
68 auto do_ = load_tile(do_dram_window);
69
70 // declare d
71 constexpr auto d_dstr =
73 o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
74
76
77 clear_tile(d); // Initialize D
78
79 constexpr auto o_spans = decltype(o)::get_distributed_spans();
80 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
81 constexpr auto i_idx = make_tuple(idx0);
82 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
83 constexpr auto i_j_idx = make_tuple(idx0, idx1);
84 d(i_idx) +=
85 (type_convert<DDataType>(o[i_j_idx]) * type_convert<DDataType>(do_[i_j_idx]));
86 });
87 });
88
89 tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
90
91 store_tile(d_dram_block_window_tmp, d);
92 }
93};
94
95} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_dot_do_o.hpp:13
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dot_do_o.hpp:16
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dot_do_o.hpp:19
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dot_do_o.hpp:18
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dot_do_o.hpp:15
static constexpr bool kPadSeqLenQ
Definition block_fmha_bwd_dot_do_o.hpp:23
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dot_do_o.hpp:20
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dot_do_o.hpp:14
static constexpr index_t kAlignmentO
Definition block_fmha_bwd_dot_do_o.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dot_do_o.hpp:31
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp &o_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, DDramBlockWindowTmp &d_dram_block_window_tmp, float p_undrop) const
Definition block_fmha_bwd_dot_do_o.hpp:36
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dot_do_o.hpp:28
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dot_do_o.hpp:22
Definition tile/core/container/sequence.hpp:49