40 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
45 (void)randval_dram_block_window_tmp;
46 (void)seqlen_qk_start;
57 unsigned long long seed,
61 bool is_store_randval_)
71 template <
typename BlockGemm,
bool IsFwd = true,
typename RandValDramBlockWindowTmp>
76 constexpr auto config =
77 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
79 constexpr bool IsWG32 = WG::kM == 32;
80 constexpr index_t MWarp = config.template at<1>();
81 constexpr index_t NWarp = config.template at<2>();
83 constexpr index_t kMPerBlock = BlockGemmShape::kM;
84 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
85 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
86 constexpr index_t kNPerStep = NWarp * WG::kN;
88 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
89 auto randval_dram_window = [&]() {
93 randval_dram_block_window_tmp.get_bottom_tensor_view(),
95 {block_origin.at(number<0>{}), seqlen_qk_start});
100 randval_dram_block_window_tmp.get_bottom_tensor_view(),
102 {seqlen_qk_start, block_origin.at(number<1>{})});
106 return randval_dram_window;
109 template <
typename BlockGemm>
112 constexpr auto config =
113 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
115 constexpr bool IsWG32 = WG::kM == 32;
116 constexpr index_t MWarp = config.template at<1>();
117 constexpr index_t NWarp = config.template at<2>();
119 constexpr index_t kMPerBlock = BlockGemmShape::kM;
120 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
121 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
122 constexpr index_t kNPerStep = NWarp * WG::kN;
124 constexpr index_t kN0 = kNPerStep / kN1;
133 randval_lds_block_desc_0,
140 return randval_lds_block_desc;
143 template <
typename BlockGemm>
146 constexpr auto config =
147 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
149 constexpr bool IsWG32 = WG::kM == 32;
150 constexpr index_t MWarp = config.template at<1>();
151 constexpr index_t NWarp = config.template at<2>();
153 constexpr index_t kMPerBlock = BlockGemmShape::kM;
154 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
155 constexpr index_t NIterPerWarp = 1;
168 constexpr auto randval_block_inner_part_dstr_encoding =
170 typename WG::BDataType,
171 typename WG::CDataType,
176 IsWG32>::CWarpDstrEncoding{};
178 constexpr auto randval_block_part_dstr_encode =
180 randval_block_inner_part_dstr_encoding);
185 template <
typename BlockGemm>
188 constexpr auto config =
189 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
191 constexpr bool IsWG32 = WG::kM == 32;
192 constexpr index_t MWarp = config.template at<1>();
193 constexpr index_t NWarp = config.template at<2>();
195 constexpr index_t kMPerBlock = BlockGemmShape::kM;
196 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
197 constexpr index_t NIterPerWarp = 1;
207 constexpr auto randval_block_part_dstr_encode =
209 typename WG::CWarpDstrEncoding{});
214 template <
typename BlockGemm,
215 typename PComputeDataType,
216 typename RandValOutputDataType,
217 typename PComputeWindow,
218 typename RandValDramWindow>
221 PComputeWindow& p_compute,
222 RandValDramWindow& randval_dram_window)
const
224 constexpr auto config =
225 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
227 constexpr bool IsWG32 = WG::kM == 32;
228 constexpr index_t MWarp = config.template at<1>();
229 constexpr index_t NWarp = config.template at<2>();
231 constexpr index_t kMPerBlock = BlockGemmShape::kM;
232 constexpr index_t kNPerBlock = BlockGemmShape::kN;
233 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
234 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
235 constexpr index_t kNPerStep = NWarp * WG::kN;
245 auto randval_dist_generated =
248 const auto randval_lds_read_window =
250 randval_lds_window.get_window_lengths(),
251 randval_lds_window.get_window_origin(),
254 const index_t start_m0_idx = randval_dram_window.get_window_origin().at(
number<0>{});
258 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
260 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
261 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
262 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
267 const unsigned long long ph_subsequence =
271 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
278 const unsigned long long ph_subsequence =
280 const index_t subtile_m0 = wg_m0 % 2;
287 if constexpr(MIterPerWarp == 1)
289 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
291 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
295 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
304 if constexpr(MIterPerWarp == 1)
306 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
308 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
312 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
314 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
319 constexpr auto randval_dist_generated_spans =
320 decltype(randval_dist_generated)::get_distributed_spans();
321 int i_random_idx = 0;
325 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
329 store_tile(randval_lds_window, randval_dist_generated);
331 const auto randval =
load_tile(randval_lds_read_window);
338 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
339 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
340 const auto randval = generate_randval(i_m0, i_n0);
343 store_tile(randval_dram_window, randval_store);
350 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
351 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
352 const auto randval = generate_randval(i_m0, i_n0);
354 constexpr auto randval_spans =
decltype(randval)::get_distributed_spans();
357 constexpr auto p_idx0 =
359 idx0.
impl_.template at<0>()>{};
360 constexpr auto p_idx1 =
362 idx1.
impl_.template at<1>(),
363 idx1.impl_.template at<2>()>{};
368 : PComputeDataType(0);
384template <
bool IsDropout_,
bool IsWG32_,
bool IsStoreRandval_>
387template <
bool IsWG32_,
bool IsStoreRandval_>
393 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
398 (void)randval_dram_block_window_tmp;
399 (void)seqlen_qk_start;
405template <
bool IsWG32_,
bool IsStoreRandval_>
414 unsigned long long seed,
415 unsigned long long offset,
420 detail::philox_per_tile)),
426 template <
typename BlockGemm,
bool IsFwd = false,
typename RandValDramBlockWindowTmp>
431 constexpr auto config =
432 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
434 constexpr bool IsWG32 = WG::kM == 32;
435 constexpr index_t MWarp = config.template at<1>();
436 constexpr index_t NWarp = config.template at<2>();
438 constexpr index_t kMPerBlock = BlockGemmShape::kM;
439 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
440 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
441 constexpr index_t kNPerStep = NWarp * WG::kN;
443 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
444 auto randval_dram_window = [&]() {
448 randval_dram_block_window_tmp.get_bottom_tensor_view(),
450 {block_origin.at(number<0>{}), seqlen_qk_start});
455 randval_dram_block_window_tmp.get_bottom_tensor_view(),
457 {seqlen_qk_start, block_origin.at(number<1>{})});
461 return randval_dram_window;
464 template <
typename BlockGemm>
467 constexpr auto config =
468 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
470 constexpr bool IsWG32 = WG::kM == 32;
471 constexpr index_t MWarp = config.template at<1>();
472 constexpr index_t NWarp = config.template at<2>();
474 constexpr index_t kMPerBlock = BlockGemmShape::kM;
475 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
476 constexpr index_t NIterPerWarp = 1;
486 constexpr auto randval_block_inner_part_dstr_encoding =
488 typename WG::BDataType,
489 typename WG::CDataType,
494 IsWG32>::CWarpDstrEncoding{};
496 std::is_same_v<
remove_cvref_t<
decltype(randval_block_inner_part_dstr_encoding)>,
497 typename WG::CWarpDstrEncoding>);
499 constexpr auto randval_block_part_dstr_encode =
501 randval_block_inner_part_dstr_encoding);
506 template <
typename BlockGemm,
507 typename RandValOutputDataType,
508 typename PComputeWindow,
509 typename RandValDramWindow>
512 PComputeWindow& p_compute,
513 RandValDramWindow& randval_dram_window)
const
515 constexpr auto config =
516 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
518 constexpr bool IsWG32 = WG::kM == 32;
519 constexpr index_t MWarp = config.template at<1>();
520 constexpr index_t NWarp = config.template at<2>();
522 constexpr index_t kMPerBlock = BlockGemmShape::kM;
523 constexpr index_t kNPerBlock = BlockGemmShape::kN;
524 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
525 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
526 constexpr index_t kNPerStep = NWarp * WG::kN;
529 auto randval_dist_generated =
535 auto generate_randval = [&](
auto i_m0,
auto i_n0) {
537 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
538 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
539 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
544 const unsigned long long ph_subsequence =
548 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
555 const unsigned long long ph_subsequence =
557 const index_t subtile_m0 = wg_m0 % 2;
564 if constexpr(MIterPerWarp == 1)
566 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
568 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
572 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
581 if constexpr(MIterPerWarp == 1)
583 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
585 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
589 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
591 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
596 constexpr auto randval_dist_generated_spans =
597 decltype(randval_dist_generated)::get_distributed_spans();
598 int i_random_idx = 0;
602 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
605 return randval_dist_generated;
608 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](
auto i_n0) {
609 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](
auto i_m0) {
610 const auto randval = generate_randval(i_m0, i_n0);
613 constexpr auto randval_spans =
decltype(randval)::get_distributed_spans();
617 constexpr auto p_idx0 =
619 idx0.
impl_.template at<0>(),
620 idx0.impl_.template at<1>(),
621 idx0.impl_.template at<2>()>{};
633 store_tile(randval_dram_window, randval_store);
Definition philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition philox_rand.hpp:75
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
constexpr index_t philox_per_tile
Definition block_dropout.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned char uint8_t
Definition stdint.h:124
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:395
static constexpr bool IsStoreRandval
Definition block_dropout.hpp:391
static constexpr bool IsDropout
Definition block_dropout.hpp:390
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition block_dropout.hpp:411
const unsigned long long ph_seed
Definition block_dropout.hpp:648
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValTileDistribution()
Definition block_dropout.hpp:465
static constexpr bool IsStoreRandval
Definition block_dropout.hpp:409
const uint8_t p_undrop_in_uint8_t
Definition block_dropout.hpp:651
const unsigned long long ph_head_offset
Definition block_dropout.hpp:649
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:428
static constexpr bool IsDropout
Definition block_dropout.hpp:408
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition block_dropout.hpp:510
const float rp_undrop
Definition block_dropout.hpp:650
Definition block_dropout.hpp:385
const uint8_t p_undrop_in_uint8_t
Definition block_dropout.hpp:378
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition block_dropout.hpp:54
const float rp_undrop
Definition block_dropout.hpp:377
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValTileDistribution()
Definition block_dropout.hpp:144
const unsigned long long ph_head_offset
Definition block_dropout.hpp:376
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValLdsBlockDescriptor()
Definition block_dropout.hpp:110
const bool is_store_randval
Definition block_dropout.hpp:379
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:73
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition block_dropout.hpp:219
const unsigned long long ph_seed
Definition block_dropout.hpp:375
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValLdsShuffleTileDistribution()
Definition block_dropout.hpp:186
Definition block_dropout.hpp:39
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:42
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution.hpp:42
static constexpr auto impl_
Definition tile_distribution.hpp:45
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192