device_multi_query_attention_forward_wmma.hpp Source File

device_multi_query_attention_forward_wmma.hpp Source File

Composable Kernel: device_multi_query_attention_forward_wmma.hpp Source File
device_multi_query_attention_forward_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <numeric>
9#include <initializer_list>
10#include <cstdlib>
11
12#include "ck/ck.hpp"
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
29// Multi-Query Attention (MQA) kernel implementation
30// Assume number of head of K,V is 1.
31// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N]
32// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O]
33template <typename DeviceOp,
34 typename GridwiseOp,
35 typename ADataType,
36 typename B0DataType,
37 typename B1DataType,
38 typename CDataType,
39 typename AElementwiseOperation,
40 typename B0ElementwiseOperation,
41 typename AccElementwiseOperation,
42 typename B1ElementwiseOperation,
43 typename CElementwiseOperation,
44 bool HasMainKBlockLoop>
45__global__ void
46#if CK_USE_LAUNCH_BOUNDS
48#endif
49 kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
50 const B0DataType* __restrict__ p_b0_grid,
51 const B1DataType* __restrict__ p_b1_grid,
52 CDataType* __restrict__ p_c_grid,
53 index_t M, // SequenceQ
54 index_t N, // SequenceK
55 index_t K, // HeadDim
56 index_t O, // SequenceK
57 index_t G0, // Batch
58 index_t G1, // HeadNum
59 float alpha,
60 bool input_permute,
61 bool output_permute)
62{
63#if(defined(__gfx11__) || defined(__gfx12__))
64
65 // clang-format off
66// ***************************************************
67 const auto q_head = G1;
68 const auto kv_head = 1;
69// Make Tensor Descriptors
70 constexpr index_t array_size = 4;
71 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, q_head, M, K};
72 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
73 input_permute
74 ? std::array<ck::index_t, array_size>{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K]
75 : std::array<ck::index_t, array_size>{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
76
77 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, kv_head, N, K};
78 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
79 input_permute
80 ? std::array<ck::index_t, array_size>{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K]
81 : std::array<ck::index_t, array_size>{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K]
82
83 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, kv_head, O, N};
84 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
85 input_permute
86 ? std::array<ck::index_t, array_size>{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O]
87 : std::array<ck::index_t, array_size>{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O]
88
89 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, q_head, M, O};
90 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
91 output_permute
92 ? std::array<ck::index_t, array_size>{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O]
93 : std::array<ck::index_t, array_size>{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
94
95 const auto a_element_op = AElementwiseOperation{};
96 const auto b0_element_op = B0ElementwiseOperation{};
97 const auto acc0_element_op = AccElementwiseOperation{alpha};
98 const auto b1_element_op = B1ElementwiseOperation{};
99 const auto c_element_op = CElementwiseOperation{};
100 // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
101
102 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
103 const auto b0_grid_desc =
104 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
105 const auto b1_grid_desc =
106 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
107 const auto c_grid_desc_m_n =
108 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
109 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
110 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
111 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
112
113 const auto a_grid_desc_g_m_k =
114 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
115 const auto b0_grid_desc_g_l_k =
116 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
117 const auto b1_grid_desc_g_n_l =
118 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
119 const auto c_grid_desc_g_m_n =
120 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
121 const auto compute_base_ptr_of_batch =
122 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
123 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
124 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
125
126 // clang-format on
127 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
128 const index_t num_blocks_per_batch =
129 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
130 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
131
132 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
133 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
134 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
135 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1)));
136 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
137 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1)));
138 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
139 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
140
141 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
142 p_b0_grid + b0_batch_offset,
143 p_b1_grid + b1_batch_offset,
144 p_c_grid + c_batch_offset,
145 p_shared,
146 a_grid_desc,
147 b0_grid_desc,
148 b1_grid_desc,
149 c_grid_desc_mblock_mperblock_nblock_nperblock,
150 a_element_op,
151 b0_element_op,
152 acc0_element_op,
153 b1_element_op,
154 c_element_op,
155 c0_matrix_mask,
156 block_2_ctile_map);
157#else
158 ignore = p_a_grid;
159 ignore = p_b0_grid;
160 ignore = p_b1_grid;
161 ignore = p_c_grid;
162 ignore = M;
163 ignore = N;
164 ignore = K;
165 ignore = O;
166 ignore = G0;
167 ignore = G1;
168 ignore = alpha;
169 ignore = input_permute;
170 ignore = output_permute;
171#endif // end of if (defined(__gfx11__))
172}
173
174// Computes C = A * B0 * B1
175// MN = MK * KL * LN
176// ^^^^^^ (Acc0)
177// ^^^^^^^^^^^ (Acc1)
178template <index_t NumDimG,
179 index_t NumDimM,
180 index_t NumDimL,
181 index_t NumDimK,
182 index_t NumDimN,
183 typename ADataType,
184 typename B0DataType,
185 typename B1DataType,
186 typename CDataType,
187 typename Acc0BiasDataType,
188 typename Acc0DataType,
189 typename Acc1BiasDataType,
190 typename Acc1DataType,
191 typename CShuffleDataType,
192 typename AElementwiseOperation,
193 typename B0ElementwiseOperation,
194 typename AccElementwiseOperation,
195 typename B1ElementwiseOperation,
196 typename CElementwiseOperation,
197 GemmSpecialization GemmSpec,
202 ck::index_t NumPrefetch,
203 ck::index_t BlockSize,
204 ck::index_t MPerBlock,
205 ck::index_t LPerBlock,
206 ck::index_t KPerBlock,
207 ck::index_t AK1,
208 ck::index_t BK1,
209 ck::index_t NPerBlock,
210 ck::index_t LTilePerBlock,
211 ck::index_t L1,
212 ck::index_t MPerWmma,
213 ck::index_t LPerWmma,
214 ck::index_t NPerWmma,
215 ck::index_t MRepeat,
216 ck::index_t LRepeat,
217 ck::index_t NRepeat,
218 typename ABlockTransferThreadClusterLengths_K0_M_K1,
219 typename ABlockTransferThreadClusterArrangeOrder,
220 typename ABlockTransferSrcAccessOrder,
221 ck::index_t ABlockTransferSrcVectorDim,
222 ck::index_t ABlockTransferSrcScalarPerVector,
223 ck::index_t ABlockTransferDstScalarPerVector_K1,
224 bool ABlockLdsAddExtraM,
225 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
226 typename B0BlockTransferThreadClusterArrangeOrder,
227 typename B0BlockTransferSrcAccessOrder,
228 ck::index_t B0BlockTransferSrcVectorDim,
229 ck::index_t B0BlockTransferSrcScalarPerVector,
230 ck::index_t B0BlockTransferDstScalarPerVector_K1,
231 bool B0BlockLdsAddExtraL,
232 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
233 typename B1BlockTransferThreadClusterArrangeOrder,
234 typename B1BlockTransferSrcAccessOrder,
235 ck::index_t B1BlockTransferSrcVectorDim,
236 ck::index_t B1BlockTransferSrcScalarPerVector,
237 ck::index_t B1BlockTransferDstScalarPerVector_L1,
238 bool B1BlockLdsAddExtraN,
239 index_t CShuffleMRepeatPerShuffle,
240 index_t CShuffleNRepeatPerShuffle,
241 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
242 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
243 MaskingSpecialization MaskingSpec,
248 NumDimM,
249 NumDimL,
250 NumDimK,
251 NumDimN,
252 ADataType,
253 B0DataType,
254 B1DataType,
255 CDataType,
256 Acc0BiasDataType,
257 Acc1BiasDataType,
258 AElementwiseOperation,
259 B0ElementwiseOperation,
260 AccElementwiseOperation,
261 B1ElementwiseOperation,
262 CElementwiseOperation,
263 MaskingSpec>
264{
265 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
266 "Number of dimension must be greater than 0");
267
268 static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
269 static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
270
271 // TODO ANT: implement bias combination
272 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
273
274 static constexpr index_t NumDimGemm0M = NumDimM;
275 static constexpr index_t NumDimGemm0N = NumDimL;
276 static constexpr index_t NumDimGemm0K = NumDimK;
277 static constexpr index_t NumDimGemm1M = NumDimM;
278 static constexpr index_t NumDimGemm1N = NumDimN;
279 static constexpr index_t NumDimGemm1K = NumDimL;
280
282
283 static constexpr auto I0 = Number<0>{};
284 static constexpr auto I1 = Number<1>{};
285 static constexpr auto I2 = Number<2>{};
286 static constexpr auto I3 = Number<3>{};
287 static constexpr auto I4 = Number<4>{};
288 static constexpr auto I5 = Number<5>{};
289 static constexpr auto I6 = Number<6>{};
290
291 static constexpr auto WmmaK = 16;
292
293 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
294 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
295 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
296
297 static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true;
298 static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
299 static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
300
301 static constexpr auto AEnableLds_manu = false;
302 static constexpr auto B0EnableLds_manu = true;
303 static constexpr auto B1EnableLds_manu = true;
304
305 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
306 static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
307 static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);
308
312 GemmSpec,
313 ASpec,
314 B0Spec,
315 B1Spec,
316 CSpec>;
317
318 __host__ __device__ static auto MakeAGridDescriptor(
319 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
320 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
321 {
322 if constexpr(AEnableLds)
323 {
325 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
326 Number<AK1>{});
327 }
328 else
329 {
330 return Transform::
332 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
333 a_gs_ms_ks_strides_vec),
338 Number<AK1>{});
339 }
340 }
341
342 __host__ __device__ static auto MakeB0GridDescriptor(
343 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
344 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
345 {
346 if constexpr(B0EnableLds)
347 {
349 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
350 b0_gs_ls_ks_strides_vec),
351 Number<BK1>{});
352 }
353 else
354 {
355 return Transform::
357 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
358 b0_gs_ls_ks_strides_vec),
363 Number<BK1>{});
364 }
365 }
366
367 __host__ __device__ static auto MakeB1GridDescriptor(
368 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
369 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
370 {
371 if constexpr(B1EnableLds)
372 {
374 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
375 b1_gs_ns_ls_strides_vec),
376 Number<L1>{});
377 }
378 else
379 {
380 return Transform::
382 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
383 b1_gs_ns_ls_strides_vec),
388 Number<L1>{});
389 }
390 }
391
// Descriptor types returned by the Make*GridDescriptor factories.
// decltype is an unevaluated context, so the value-initialized ({}) length/stride arrays are
// never used at runtime. AGridDesc is passed to GridwiseOp as a template argument.
392 using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
393 using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
394 using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
400
401 __host__ __device__ constexpr static auto make_MaskOutPredicate()
402 {
403 if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
404 {
405 return MaskDisabledPredicate{};
406 }
407 else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
408 {
410 }
411 }
413
415 {
416 __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
417 const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
418 const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
419 const CGridDesc_G_M_N& c_grid_desc_g_m_n)
420 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
421 b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
422 b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
423 c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
424 {
425 }
426
427 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
428 {
429 return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
430 }
431
432 __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
433 {
434 return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
435 }
436
437 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
438 {
439 return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
440 }
441
442 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
443 {
444 return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
445 }
446
447 private:
448 AGridDesc_G_M_K a_grid_desc_g_m_k_;
449 B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
450 B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
451 CGridDesc_G_M_N c_grid_desc_g_m_n_;
452 };
453
454 // GridwiseOp
456 // DataType Family
457 ADataType,
458 B0DataType,
459 Acc0DataType,
460 B1DataType,
461 Acc1DataType,
462 CShuffleDataType,
463 CDataType,
464 // ElementwiseOp Family
465 AElementwiseOperation,
466 B0ElementwiseOperation,
467 AccElementwiseOperation,
468 B1ElementwiseOperation,
469 CElementwiseOperation,
471 // InMemory Data Descriptor
472 AGridDesc,
476 // Tiling Family
477 MPerBlock,
478 LPerBlock,
479 KPerBlock,
480 AK1,
481 BK1,
482 NPerBlock,
483 LTilePerBlock,
484 L1,
485 MPerWmma,
486 LPerWmma,
487 NPerWmma,
488 MRepeat,
489 LRepeat,
490 NRepeat,
491 // ThreadCluster Family
492 BlockSize,
493 ABlockTransferThreadClusterLengths_K0_M_K1,
494 ABlockTransferThreadClusterArrangeOrder,
495 ABlockTransferSrcAccessOrder,
496 ABlockTransferSrcVectorDim,
497 ABlockTransferSrcScalarPerVector,
498 ABlockTransferDstScalarPerVector_K1,
499 true,
501 ABlockLdsAddExtraM,
502 B0BlockTransferThreadClusterLengths_K0_L_K1,
503 B0BlockTransferThreadClusterArrangeOrder,
504 B0BlockTransferSrcAccessOrder,
505 B0BlockTransferSrcVectorDim,
506 B0BlockTransferSrcScalarPerVector,
507 B0BlockTransferDstScalarPerVector_K1,
508 true,
510 B0BlockLdsAddExtraL,
511 B1BlockTransferThreadClusterLengths_L0_N_L1,
512 B1BlockTransferThreadClusterArrangeOrder,
513 B1BlockTransferSrcAccessOrder,
514 B1BlockTransferSrcVectorDim,
515 B1BlockTransferSrcScalarPerVector,
516 B1BlockTransferDstScalarPerVector_L1,
517 false,
519 B1BlockLdsAddExtraN,
520 CShuffleMRepeatPerShuffle,
521 CShuffleNRepeatPerShuffle,
522 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
523 CShuffleBlockTransferScalarPerVector_NPerBlock,
526 NumPrefetch,
527 LoopSched,
528 PipelineVer>;
529
530 struct RawArg : public BaseArgument
531 {
532 RawArg(const ADataType* p_a_grid,
533 const B0DataType* p_b0_grid,
534 const B1DataType* p_b1_grid,
535 CDataType* p_c_grid,
536 index_t M,
537 index_t N,
538 index_t K,
539 index_t O,
540 index_t G0,
541 index_t G1,
542 float alpha,
543 bool input_permute,
544 bool output_permute)
545 : p_a_grid_{p_a_grid},
546 p_b0_grid_{p_b0_grid},
547 p_b1_grid_{p_b1_grid},
548 p_c_grid_{p_c_grid},
549 M_{M},
550 N_{N},
551 K_{K},
552 O_{O},
553 G0_{G0},
554 G1_{G1},
555 alpha_{alpha},
556 input_permute_{input_permute},
557 output_permute_{output_permute}
558 {
559 }
560 // Pointers
561 const ADataType* p_a_grid_;
562 const B0DataType* p_b0_grid_;
563 const B1DataType* p_b1_grid_;
564 CDataType* p_c_grid_;
565
566 // Raw Problem Size
573 float alpha_;
576 };
577
578 static auto MakeArgument(const ADataType* p_a,
579 const B0DataType* p_b0,
580 const B1DataType* p_b1,
581 CDataType* p_c,
582 index_t M,
583 index_t N,
584 index_t K,
585 index_t O,
586 index_t G0,
587 index_t G1,
588 float alpha,
589 bool input_permute,
590 bool output_permute)
591 {
592 return RawArg{
593 p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
594 }
595
596 static bool IsSupportedArgument(const RawArg& arg)
597 {
599 {
601 {
602 printf("DeviceOp: Acc0 Type err");
603 return false;
604 }
605
607 {
608 printf("DeviceOp: Acc1 Type err");
609 return false;
610 }
611 }
612 else
613 {
614 printf("DeviceOp: Arch err");
615 return false;
616 }
617
618 constexpr index_t array_size = 4;
619 ck::index_t G0 = arg.G0_;
620 ck::index_t G1 = arg.G1_;
621 ck::index_t M = arg.M_;
622 ck::index_t N = arg.N_;
623 ck::index_t K = arg.K_;
624 ck::index_t O = arg.O_;
625 bool input_permute = arg.input_permute_;
626 bool output_permute = arg.output_permute_;
627
628 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
629 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
630 input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
631 // A layout [G0, M, G1, K]
632 : std::array<ck::index_t, array_size>{
633 G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
634
635 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
636 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
637 input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
638 // B0 layout [G0, N, G1, K]
639 : std::array<ck::index_t, array_size>{
640 G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
641
642 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
643 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
644 input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
645 // B1 layout [G0, N, G1, O]
646 : std::array<ck::index_t, array_size>{
647 G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
648
649 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
650 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
651 output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
652 // C layout [G0, M, G1, O]
653 : std::array<ck::index_t, array_size>{
654 G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
655
656 const auto a_grid_desc =
657 DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
658 const auto b0_grid_desc =
659 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
660 const auto b1_grid_desc =
661 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
662 const auto c_grid_desc_m_n =
663 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
664
665 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
666
667 const auto c_grid_desc_g_m_n =
668 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
669 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
670
672 a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
673 {
674 return false;
675 }
676
677 // Check if C permute dimension matches GEMM + GEMM shape
678 const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
679
680 if(!(c_g == batch_count))
681 {
682 printf("DeviceOp: BatchCount err");
683 return false;
684 }
685
686 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
687 // vector is out of bounds
688 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
689 const auto MzRaw = M;
690 const auto LzRaw = N;
691 const auto KzRaw = K;
692 const auto NzRaw = O;
693
694 // Check scalar per vector requirement
695 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
696 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
697 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
698 const auto c_extent_lowest = NzRaw;
699
700 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
701 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
702 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
703 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
704 {
705 printf("DeviceOp: Data Transfer Vector scalar err");
706 return false;
707 }
708
709 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
710 a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
711 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
712 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
713 b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
714 b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
715 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
716 b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
717 b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
718 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
719 c_gs_ms_os_strides[NumDimG + NumDimM - 1],
720 c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
721
722 // Check vector load/store requirement
723 const auto a_stride_lowest =
724 ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
725 const auto b0_stride_lowest =
726 B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
727 const auto b1_stride_lowest =
728 B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
729 const auto c_stride_lowest = c_mz_nz_strides_[1];
730
731 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
732 c_stride_lowest == 1))
733 {
734 printf("DeviceOp: Data Vectorize transfer err");
735 return false;
736 }
737
738 return true;
739 }
740
741 // polymorphic
742 bool IsSupportedArgument(const BaseArgument* p_arg) override
743 {
744 return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
745 }
746
747 // Argument
748 struct Argument : public BaseArgument
749 {
751 const ADataType* p_a_grid,
752 const B0DataType* p_b0_grid,
753 const B1DataType* p_b1_grid,
754 CDataType* p_c_grid,
755 const std::array<void*, NumAcc0Bias> p_acc0_biases,
756 const std::array<void*, NumAcc1Bias> p_acc1_biases,
757 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
758 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
759 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
760 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
761 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
762 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
763 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
764 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
765 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
766 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
767 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
768 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
769 const index_t M01,
770 const index_t N01,
771 AElementwiseOperation a_element_op,
772 B0ElementwiseOperation b0_element_op,
773 AccElementwiseOperation acc_element_op,
774 B1ElementwiseOperation b1_element_op,
775 CElementwiseOperation c_element_op)
776 : p_a_grid_{p_a_grid},
777 p_b0_grid_{p_b0_grid},
778 p_b1_grid_{p_b1_grid},
779 p_c_grid_{p_c_grid},
780 a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
782 DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
784 DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
786 Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
788 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
790 Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
792 Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
794 Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
796 block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
797 a_element_op_{a_element_op},
798 b0_element_op_{b0_element_op},
799 acc_element_op_{acc_element_op},
800 b1_element_op_{b1_element_op},
801 c_element_op_{c_element_op},
803 raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
804 b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
805 b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
806 b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
807 a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
808 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
809 b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
810 b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
811 b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
812 b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
813 c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
814 c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
818 {
819 // TODO ANT: implement bias addition
820 ignore = p_acc0_biases;
821 ignore = p_acc1_biases;
822 ignore = acc0_biases_gs_ms_ls_lengths;
823 ignore = acc0_biases_gs_ms_ls_strides;
824 ignore = acc1_biases_gs_ms_ns_lengths;
825 ignore = acc1_biases_gs_ms_ns_strides;
826
829 {
833 }
834 }
835
836 // Pointers
837 const ADataType* p_a_grid_;
838 const B0DataType* p_b0_grid_;
839 const B1DataType* p_b1_grid_;
840 CDataType* p_c_grid_;
841
842 // Tensor Descriptors
847
852
855
856 // Block to Tile mapping
858
859 // ElementwiseOp
860 AElementwiseOperation a_element_op_;
861 B0ElementwiseOperation b0_element_op_;
862 AccElementwiseOperation acc_element_op_;
863 B1ElementwiseOperation b1_element_op_;
864 CElementwiseOperation c_element_op_;
865
866 // check C0 masking and padding
868
869 // Strides for the last M/N/K dimensions of A/B0/B1/C
870 // for sanity check of vector load/store
871 std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
872 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
873 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
874 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
875 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;
876
878 // Batch Offset
880 };
881
882 struct Invoker : public BaseInvoker
883 {
885
886 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
887 {
888 const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
889 const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);
890
891 const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
892 const auto K = arg.K_;
893 // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
894 auto launch_kernel = [&](auto has_main_k_block_loop) {
897 ADataType,
898 B0DataType,
899 B1DataType,
900 CDataType,
901 AElementwiseOperation,
902 B0ElementwiseOperation,
903 AccElementwiseOperation,
904 B1ElementwiseOperation,
905 CElementwiseOperation,
906 has_main_k_block_loop>;
907
908 return launch_and_time_kernel(stream_config,
909 kernel,
910 dim3(grid_size),
911 dim3(BlockSize),
912 0,
913 arg.p_a_grid_,
914 arg.p_b0_grid_,
915 arg.p_b1_grid_,
916 arg.p_c_grid_,
917 arg.M_,
918 arg.N_,
919 arg.K_,
920 arg.O_,
921 arg.G0_,
922 arg.G1_,
923 arg.alpha_,
924 arg.input_permute_,
925 arg.output_permute_);
926 };
927
929 {
930 return launch_kernel(integral_constant<bool, true>{});
931 }
932 else
933 {
934 return launch_kernel(integral_constant<bool, false>{});
935 }
936 }
937
938 // polymorphic
939 float Run(const BaseArgument* p_arg,
940 const StreamConfig& stream_config = StreamConfig{}) override
941 {
942 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
943 }
944 };
945
946 static constexpr bool IsValidCompilationParameter()
947 {
948 // TODO: properly implement this check
949 return true;
950 }
951#if 0
952 static bool IsSupportedArgument(const Argument& arg)
953 {
955 {
957 {
958 printf("DeviceOp: Acc0 Type err");
959 return false;
960 }
961
963 {
964 printf("DeviceOp: Acc1 Type err");
965 return false;
966 }
967 }
968 else
969 {
970 printf("DeviceOp: Arch err");
971 return false;
972 }
973
974 if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
975 arg.b0_grid_desc,
976 arg.b1_grid_desc,
977 arg.c_grid_desc_m_n_,
978 arg.block_2_ctile_map_))
979 {
980 return false;
981 }
982
983 // Check if C permute dimension matches GEMM + GEMM shape
984 const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
985
986 if(!(c_g == arg.batch_count_))
987 {
988 printf("DeviceOp: BatchCount err");
989 return false;
990 }
991
992 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
993 // vector is out of bounds
994 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
995 const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
996 const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
997 const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
998 const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
999
1000 // Check scalar per vector requirement
1001 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
1002 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
1003 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
1004 const auto c_extent_lowest = NzRaw;
1005
1006 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
1007 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
1008 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
1009 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
1010 {
1011 printf("DeviceOp: Data Transfer Vector scalar err");
1012 return false;
1013 }
1014
1015 // Check vector load/store requirement
1016 const auto a_stride_lowest =
1017 ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
1018 const auto b0_stride_lowest =
1019 B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
1020 const auto b1_stride_lowest =
1021 B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
1022 const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
1023
1024 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
1025 c_stride_lowest == 1))
1026 {
1027 printf("DeviceOp: Data Vectorize transfer err");
1028 return false;
1029 }
1030
1031 return true;
1032 }
1033
1034 // polymorphic
1035 bool IsSupportedArgument(const BaseArgument* p_arg) override
1036 {
1037 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1038 }
1039
1040 static auto MakeArgument(
1041 const ADataType* p_a,
1042 const B0DataType* p_b0,
1043 const B1DataType* p_b1,
1044 CDataType* p_c,
1045 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1046 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1047 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
1048 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
1049 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
1050 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
1051 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
1052 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
1053 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
1054 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
1055 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1056 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1057 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1058 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1059 AElementwiseOperation a_element_op,
1060 B0ElementwiseOperation b0_element_op,
1061 AccElementwiseOperation acc_element_op,
1062 B1ElementwiseOperation b1_element_op,
1063 CElementwiseOperation c_element_op)
1064 {
1065 return Argument{p_a,
1066 p_b0,
1067 p_b1,
1068 p_c,
1069 p_acc0_biases,
1070 p_acc1_biases,
1071 a_gs_ms_ks_lengths,
1072 a_gs_ms_ks_strides,
1073 b0_gs_ls_ks_lengths,
1074 b0_gs_ls_ks_strides,
1075 b1_gs_ns_ls_lengths,
1076 b1_gs_ns_ls_strides,
1077 c_gs_ms_ns_lengths,
1078 c_gs_ms_ns_strides,
1079 acc0_biases_gs_ms_ls_lengths,
1080 acc0_biases_gs_ms_ls_strides,
1081 acc1_biases_gs_ms_ns_lengths,
1082 acc1_biases_gs_ms_ns_strides,
1083 1,
1084 1,
1085 a_element_op,
1086 b0_element_op,
1087 acc_element_op,
1088 b1_element_op,
1089 c_element_op};
1090 }
1091#endif
1092
1093 // polymorphic
1094 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1095 const void* p_a,
1096 const void* p_b0,
1097 const void* p_b1,
1098 void* p_c,
1099 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1100 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1101 const std::vector<index_t>& a_gs_ms_ks_lengths,
1102 const std::vector<index_t>& a_gs_ms_ks_strides,
1103 const std::vector<index_t>& b0_gs_ls_ks_lengths,
1104 const std::vector<index_t>& b0_gs_ls_ks_strides,
1105 const std::vector<index_t>& b1_gs_ns_ls_lengths,
1106 const std::vector<index_t>& b1_gs_ns_ls_strides,
1107 const std::vector<index_t>& c_gs_ms_ns_lengths,
1108 const std::vector<index_t>& c_gs_ms_ns_strides,
1109 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1110 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1111 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1112 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1113 AElementwiseOperation a_element_op,
1114 B0ElementwiseOperation b0_element_op,
1115 AccElementwiseOperation acc_element_op,
1116 B1ElementwiseOperation b1_element_op,
1117 CElementwiseOperation c_element_op) override
1118 {
1119 std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
1120 std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
1121 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
1122 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
1123 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
1124 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
1125 std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
1126 std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
1127 std::transform(a_gs_ms_ks_lengths.begin(),
1128 a_gs_ms_ks_lengths.end(),
1129 a_lengths.begin(),
1130 [](index_t i) { return i; });
1131 std::transform(a_gs_ms_ks_strides.begin(),
1132 a_gs_ms_ks_strides.end(),
1133 a_strides.begin(),
1134 [](index_t i) { return i; });
1135 std::transform(b0_gs_ls_ks_lengths.begin(),
1136 b0_gs_ls_ks_lengths.end(),
1137 b0_lengths.begin(),
1138 [](index_t i) { return i; });
1139 std::transform(b0_gs_ls_ks_strides.begin(),
1140 b0_gs_ls_ks_strides.end(),
1141 b0_strides.begin(),
1142 [](index_t i) { return i; });
1143 std::transform(b1_gs_ns_ls_lengths.begin(),
1144 b1_gs_ns_ls_lengths.end(),
1145 b1_lengths.begin(),
1146 [](index_t i) { return i; });
1147 std::transform(b1_gs_ns_ls_strides.begin(),
1148 b1_gs_ns_ls_strides.end(),
1149 b1_strides.begin(),
1150 [](index_t i) { return i; });
1151 std::transform(c_gs_ms_ns_lengths.begin(),
1152 c_gs_ms_ns_lengths.end(),
1153 c_lengths.begin(),
1154 [](index_t i) { return i; });
1155 std::transform(c_gs_ms_ns_strides.begin(),
1156 c_gs_ms_ns_strides.end(),
1157 c_strides.begin(),
1158 [](index_t i) { return i; });
1159 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
1160 static_cast<const B0DataType*>(p_b0),
1161 static_cast<const B1DataType*>(p_b1),
1162 static_cast<CDataType*>(p_c),
1163 p_acc0_biases,
1164 p_acc1_biases,
1165 a_lengths,
1166 a_strides,
1167 b0_lengths,
1168 b0_strides,
1169 b1_lengths,
1170 b1_strides,
1171 c_lengths,
1172 c_strides,
1173 acc0_biases_gs_ms_ls_lengths,
1174 acc0_biases_gs_ms_ls_strides,
1175 acc1_biases_gs_ms_ns_lengths,
1176 acc1_biases_gs_ms_ns_strides,
1177 1,
1178 1,
1179 a_element_op,
1180 b0_element_op,
1181 acc_element_op,
1182 b1_element_op,
1183 c_element_op);
1184 }
1185
1186 static auto MakeInvoker() { return Invoker{}; }
1187
1188 // polymorphic
1189 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1190 {
1191 return std::make_unique<Invoker>(Invoker{});
1192 }
1193
1194 // polymorphic
1195 std::string GetTypeString() const override
1196 {
1197 auto str = std::stringstream();
1198
1199 std::map<LoopScheduler, std::string> LoopSchedToString{
1200 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
1201
1202 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
1203 {PipelineVersion::v2, "v2"}};
1204
1205 // clang-format off
1206 str << "DeviceMultiQueryAttentionForward_Wmma"
1207 << "<"
1208 << BlockSize << ", "
1209 << MPerBlock << ", "
1210 << LPerBlock << ", "
1211 << KPerBlock << ", "
1212 << AK1 << ", "
1213 << BK1 << ", "
1214 << MPerBlock << ", "
1215 << NPerBlock << ", "
1216 << LTilePerBlock << ", "
1217 << L1 << ", "
1218 << getGemmSpecializationString(GemmSpec) << ", "
1219 << "ASpec" << getTensorSpecializationString(ASpec) << ", "
1220 << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
1221 << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
1222 << "CSpec" << getTensorSpecializationString(CSpec) << ", "
1223 << getMaskingSpecializationString(MaskingSpec)
1224 << ">"
1225 << " AEnableLds: "
1226 << AEnableLds << ", "
1227 << "B0EnableLds: "
1228 << B0EnableLds << ", "
1229 << "B1EnableLds: "
1230 << B1EnableLds << ", "
1231 << "NumPrefetch: "
1232 << NumPrefetch << ", "
1233 << "LoopScheduler: "
1234 << LoopSchedToString[LoopSched] << ", "
1235 << "PipelineVersion: "
1236 << PipelineVersionToString[PipelineVer];
1237 // clang-format on
1238
1239 return str.str();
1240 }
1241};
1242
1243} // namespace device
1244} // namespace tensor_operation
1245} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_multi_query_attention_wmma(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:49
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:682
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:645
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition device_multi_query_attention_forward_wmma.hpp:749
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_multi_query_attention_forward_wmma.hpp:879
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_multi_query_attention_forward_wmma.hpp:849
B1ElementwiseOperation b1_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:863
const B0DataType * p_b0_grid_
Definition device_multi_query_attention_forward_wmma.hpp:838
AccElementwiseOperation acc_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:862
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_multi_query_attention_forward_wmma.hpp:750
const B1DataType * p_b1_grid_
Definition device_multi_query_attention_forward_wmma.hpp:839
B0GridDesc b0_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:844
B1GridDesc b1_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:845
CGridDesc_M_N c_grid_desc_m_n_
Definition device_multi_query_attention_forward_wmma.hpp:846
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:873
const ADataType * p_a_grid_
Definition device_multi_query_attention_forward_wmma.hpp:837
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_multi_query_attention_forward_wmma.hpp:854
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_multi_query_attention_forward_wmma.hpp:851
index_t batch_count_
Definition device_multi_query_attention_forward_wmma.hpp:877
AGridDesc a_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:843
C0MatrixMask c0_matrix_mask_
Definition device_multi_query_attention_forward_wmma.hpp:867
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_multi_query_attention_forward_wmma.hpp:857
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:872
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:875
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_multi_query_attention_forward_wmma.hpp:850
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_multi_query_attention_forward_wmma.hpp:871
CDataType * p_c_grid_
Definition device_multi_query_attention_forward_wmma.hpp:840
B0ElementwiseOperation b0_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:861
CElementwiseOperation c_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:864
AElementwiseOperation a_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:860
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_multi_query_attention_forward_wmma.hpp:848
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:874
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:437
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_multi_query_attention_forward_wmma.hpp:416
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:432
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:442
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:427
Definition device_multi_query_attention_forward_wmma.hpp:883
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_multi_query_attention_forward_wmma.hpp:886
DeviceOp::RawArg Argument
Definition device_multi_query_attention_forward_wmma.hpp:884
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_multi_query_attention_forward_wmma.hpp:939
Definition device_multi_query_attention_forward_wmma.hpp:531
index_t G1_
Definition device_multi_query_attention_forward_wmma.hpp:572
const B0DataType * p_b0_grid_
Definition device_multi_query_attention_forward_wmma.hpp:562
float alpha_
Definition device_multi_query_attention_forward_wmma.hpp:573
const ADataType * p_a_grid_
Definition device_multi_query_attention_forward_wmma.hpp:561
bool input_permute_
Definition device_multi_query_attention_forward_wmma.hpp:574
bool output_permute_
Definition device_multi_query_attention_forward_wmma.hpp:575
const B1DataType * p_b1_grid_
Definition device_multi_query_attention_forward_wmma.hpp:563
index_t K_
Definition device_multi_query_attention_forward_wmma.hpp:569
index_t M_
Definition device_multi_query_attention_forward_wmma.hpp:567
CDataType * p_c_grid_
Definition device_multi_query_attention_forward_wmma.hpp:564
index_t N_
Definition device_multi_query_attention_forward_wmma.hpp:568
index_t O_
Definition device_multi_query_attention_forward_wmma.hpp:570
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:532
index_t G0_
Definition device_multi_query_attention_forward_wmma.hpp:571
Definition device_multi_query_attention_forward_wmma.hpp:264
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_multi_query_attention_forward_wmma.hpp:396
static constexpr auto I5
Definition device_multi_query_attention_forward_wmma.hpp:288
static constexpr auto WmmaK
Definition device_multi_query_attention_forward_wmma.hpp:291
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_multi_query_attention_forward_wmma.hpp:1094
static constexpr index_t NumAcc0Bias
Definition device_multi_query_attention_forward_wmma.hpp:268
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:367
static constexpr auto NWaves
Definition device_multi_query_attention_forward_wmma.hpp:295
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:318
static constexpr auto B0EnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:298
static constexpr index_t NumDimGemm1K
Definition device_multi_query_attention_forward_wmma.hpp:279
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_multi_query_attention_forward_wmma.hpp:398
static constexpr bool IsValidCompilationParameter()
Definition device_multi_query_attention_forward_wmma.hpp:946
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_multi_query_attention_forward_wmma.hpp:455
DeviceMultiQueryAttentionForward_Wmma DeviceOp
Definition device_multi_query_attention_forward_wmma.hpp:281
static constexpr index_t NumDimGemm0N
Definition device_multi_query_attention_forward_wmma.hpp:275
static constexpr index_t NumDimGemm1N
Definition device_multi_query_attention_forward_wmma.hpp:278
static constexpr auto AEnableLds
Definition device_multi_query_attention_forward_wmma.hpp:305
static constexpr auto B0EnableLds
Definition device_multi_query_attention_forward_wmma.hpp:306
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_multi_query_attention_forward_wmma.hpp:394
static bool IsSupportedArgument(const RawArg &arg)
Definition device_multi_query_attention_forward_wmma.hpp:596
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_multi_query_attention_forward_wmma.hpp:393
static constexpr index_t NumDimGemm1M
Definition device_multi_query_attention_forward_wmma.hpp:277
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_multi_query_attention_forward_wmma.hpp:742
static constexpr auto I2
Definition device_multi_query_attention_forward_wmma.hpp:285
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_multi_query_attention_forward_wmma.hpp:399
static constexpr auto LWaves
Definition device_multi_query_attention_forward_wmma.hpp:294
static constexpr auto B1EnableLds
Definition device_multi_query_attention_forward_wmma.hpp:307
static constexpr auto I1
Definition device_multi_query_attention_forward_wmma.hpp:284
static constexpr auto I3
Definition device_multi_query_attention_forward_wmma.hpp:286
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_multi_query_attention_forward_wmma.hpp:392
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_multi_query_attention_forward_wmma.hpp:395
static constexpr auto I6
Definition device_multi_query_attention_forward_wmma.hpp:289
static constexpr auto MWaves
Definition device_multi_query_attention_forward_wmma.hpp:293
static constexpr auto AEnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:301
static constexpr auto AEnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:297
static constexpr auto B0EnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:302
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:342
static constexpr auto I0
Definition device_multi_query_attention_forward_wmma.hpp:283
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_multi_query_attention_forward_wmma.hpp:309
static constexpr index_t NumDimGemm0M
Definition device_multi_query_attention_forward_wmma.hpp:274
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_multi_query_attention_forward_wmma.hpp:401
static constexpr index_t NumDimGemm0K
Definition device_multi_query_attention_forward_wmma.hpp:276
std::string GetTypeString() const override
Definition device_multi_query_attention_forward_wmma.hpp:1195
static auto MakeInvoker()
Definition device_multi_query_attention_forward_wmma.hpp:1186
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:578
static constexpr auto B1EnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:299
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_multi_query_attention_forward_wmma.hpp:1189
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_multi_query_attention_forward_wmma.hpp:412
static constexpr index_t NumAcc1Bias
Definition device_multi_query_attention_forward_wmma.hpp:269
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_multi_query_attention_forward_wmma.hpp:397
static constexpr auto B1EnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:303
static constexpr auto I4
Definition device_multi_query_attention_forward_wmma.hpp:287
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43