fused_moegemm_pipeline_flatmm_ex.hpp Source File

fused_moegemm_pipeline_flatmm_ex.hpp Source File

Composable Kernel: fused_moegemm_pipeline_flatmm_ex.hpp Source File
fused_moegemm_pipeline_flatmm_ex.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12/*
13This pipeline handles a gemm (actually 2 gemms) with one very small operand (the tokens) and one very big operand (the weights);
14the pipeline is designed so that all waves are laid out along the gemm-N dim (gemm-M uses only 1 wave).
15
16 <----- gemm-N ------>
17 +----+----+----+----+
18 | w0 | w1 | w2 | w3 | gemm-m
19 +----+----+----+----+
20*/
21template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23{
26
27 using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28
29 using ADataType = typename Problem::ADataType;
30 using GDataType = typename Problem::GDataType;
31 using DDataType = typename Problem::DDataType;
32 using AccDataType = typename Problem::AccDataType;
33 using ODataType = typename Problem::ODataType;
34 using AScaleDataType = typename Problem::AScaleDataType;
35 using GScaleDataType = typename Problem::GScaleDataType;
36 using DScaleDataType = typename Problem::DScaleDataType;
37 using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38 using TopkWeightDataType = typename Problem::TopkWeightDataType;
39 using IndexDataType = typename Problem::IndexDataType;
40 using YDataType = typename Problem::YDataType;
41
42 using Traits = typename Problem::Traits;
43
44 static constexpr bool IsGateOnly = Traits::IsGateOnly;
45 static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46 static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47 static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48
49 static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50 static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51 static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52 static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53
58
59 static constexpr index_t kBlockPerCu = []() {
60 if constexpr(Problem::kBlockPerCu != -1)
61 return Problem::kBlockPerCu;
62 else
63 {
64 // minimize occupancy
65 return 2;
66 }
67 }();
68
69 static constexpr const char* name = "fused_moe_flatmm";
70
71 // TODO: there are multiple buffers
73 {
74 return Policy::template GetSmemSize_A<Problem>();
75 }
76
78 {
79 return Policy::template GetSmemSize<Problem>();
80 }
81
82 // this is the thread-offset along row/col
84 {
85 constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
86 const auto a_coord = a_dist.calculate_index();
87 return a_coord;
88 }
89
90 // this is the thread-offset along row/col
92 {
93 constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
94 const auto o_coord = o_dist.calculate_index();
95 return o_coord;
96 }
97
98 template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
99 CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
100 const GWindow& g_window_,
101 const DWindow& d_window_,
102 OWindow& o_window_,
103 TopkWeightDataType /*topk_weight*/,
104 CK_TILE_LDS_ADDR void* smem,
105 index_t hidden_size,
106 index_t intermediate_size)
107 {
108 _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
109 constexpr auto NEG1 = number<-1>{};
110 constexpr auto I0 = number<0>{};
111 constexpr auto I1 = number<1>{};
112 constexpr auto TRUE = bool_constant<true>{};
113 constexpr auto FALSE = bool_constant<false>{};
114
115 CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
116 CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
117 reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
118 Policy::template GetSmemSize_A<Problem>());
119
120 auto g_view = g_window_.get_bottom_tensor_view();
121
122 auto u_view = [&]() {
123 if constexpr(IsGateOnly)
124 {
125 return g_view;
126 }
127 else
128 {
129 index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
130 index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
131
132 const GDataType* g_ptr =
133 g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
134 const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
135
137 u_ptr,
139 make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
141 number<1>{});
142 const auto u_view_1_ =
143 pad_tensor_view(u_view_,
148 return u_view_1_;
149 }
150 }();
151
152 auto a_win = make_tile_window_linear(
153 a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
154 auto g_win =
155 make_tile_window_linear(g_window_,
156 Policy::template MakeGlobalTileDistribution_G<Problem>(),
158 auto d_win =
159 make_tile_window_linear(d_window_,
160 Policy::template MakeGlobalTileDistribution_D<Problem>(),
162 auto o_win = make_tile_window_linear(
163 o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
164
165 using g_thread_type = decltype(load_tile(g_win));
166 using d_thread_type = decltype(load_tile(d_win));
167
168 using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
169 using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
170 auto warp_gemm_0 = WarpGemm0{};
171 auto warp_gemm_1 = WarpGemm1{};
172
173 // issues_warps_lanes
174 auto a_sst_win0 =
176 smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
177 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
178 {0, 0, 0});
179
180 auto a_sst_win1 =
182 smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
183 Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
184 {0, 0, 0});
185 // m*k
186 auto a_sld_win0 = [&]() {
187 using WG = WarpGemm0;
188 constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
196 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
197 a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
200 smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
201 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
202 {0, 0},
203 make_static_tile_distribution(a_block_dstr_encode));
204 }();
205
206 // m*k
207 auto a_sld_win1 = [&]() {
208 using WG = WarpGemm0;
209 constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
217 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
218 a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
221 smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
222 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
223 {0, 0},
224 make_static_tile_distribution(a_block_dstr_encode));
225 }();
226
227 auto bridge_sst_win = [&]() {
228 return make_tile_window(
230 reinterpret_cast<YDataType*>(smem),
231 Policy::template MakeBridgeLdsStoreDesc<Problem>()),
232 Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
233 {0, 0});
234 }();
235
236 auto bridge_sld_win = [&]() {
239 reinterpret_cast<YDataType*>(smem),
240 Policy::template MakeBridgeLdsLoadDesc<Problem>()),
241 Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
242 {0, 0},
243 Policy::template MakeYTileDistribution<Problem>());
244 }();
245
246 // also OK with C array, 2 register buffer
248
249 constexpr auto issues_a = number<a_win.get_num_of_access()>{};
250 constexpr auto issues_g = number<g_win.get_num_of_access()>{};
251 // constexpr auto issues_d = number<d_win.get_num_of_access()>{};
252 // constexpr auto issues_o = number<o_win.get_num_of_access()>{};
253 constexpr auto issues_gemm0 =
254 number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
255 warp_gemm_0.get_num_of_access()>{};
256 constexpr auto issues_gemm1 =
257 number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
258 warp_gemm_1.get_num_of_access()>{};
259 // constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
260
261 const index_t num_blocks_k0 =
262 (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
263 const index_t num_blocks_n1 =
264 (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
265
266 using a_thread_type = decltype(load_tile(a_sld_win0));
268
269 auto gld_a = [&]<typename PreNop = bool_constant<false>>(
270 auto& a_store_, auto i_access, PreNop = {}) {
271 async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
272 };
273 auto move_a = [&]() {
275 };
276 auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
277 load_tile_raw(a_, win_, i_access);
278 };
279
280 auto gld_g =
281 [&]<typename PreNop = bool_constant<false>>(auto& g_, auto i_access, PreNop = {}) {
282 if constexpr(IsGateOnly)
283 {
284 // TODO: hack!
285 if constexpr(i_access.value == 0)
286 {
287 g_win.bottom_tensor_view_ = g_view;
288 }
289 else if constexpr(i_access.value == issues_g / 2)
290 {
291 g_win.bottom_tensor_view_ = u_view;
292 }
293 }
294 load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
295 };
296 auto move_g = [&]() {
298 };
300
301 auto gld_d =
302 [&]<typename PreNop = bool_constant<false>>(auto& d_, auto i_access, PreNop = {}) {
303 load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
304 };
305 auto move_d = [&]() {
306 // d move along gemm-n
308 };
309
310 auto atomic_add_o =
311 [&]<typename PreNop = bool_constant<false>>(auto& o_, auto i_access, PreNop = {}) {
312 update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
313 };
314
315 auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
316 auto acc_1s = generate_tuple(
317 [&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
318
319 // clang-format off
320 auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
321 (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
322 using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
323
324 constexpr auto repeat_sub = WarpGemm::get_num_of_access();
325 constexpr auto repeat_m = BlockShape::Repeat_M0;
326 // constexpr auto repeat_n = BlockShape::Repeat_N0;
327 constexpr auto repeat_k = BlockShape::Repeat_K0;
328 // loop order n->m->k
329 constexpr auto i_sub = i_access % repeat_sub;
330 constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
331 constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
332 constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
333
334 using AWarpTensor = typename WarpGemm::AWarpTensor;
335 using BWarpTensor = typename WarpGemm::BWarpTensor;
336 using CWarpTensor = typename WarpGemm::CWarpTensor;
337 using AWarpDstr = typename WarpGemm::AWarpDstr;
338 using BWarpDstr = typename WarpGemm::BWarpDstr;
339 using CWarpDstr = typename WarpGemm::CWarpDstr;
340
341 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
342 constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
343 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
344
345 constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
346 constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
347 constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
348
349 AWarpTensor w_a;
350 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
351 merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
352 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
353
354 BWarpTensor w_b;
355 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
356 merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
357 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
358
359 CWarpTensor w_c;
360 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
361 merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
362 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
363
364 warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
365
366 t_c.set_y_sliced_thread_data(
367 merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
368 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
369 w_c.get_thread_buffer());
370 };
371 // clang-format on
372
373 // clang-format off
374 auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
375 (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
376 using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
377
378 constexpr auto repeat_sub = WarpGemm::get_num_of_access();
379 constexpr auto repeat_m = BlockShape::Repeat_M0;
380 // constexpr auto repeat_n = BlockShape::Repeat_N0;
381 constexpr auto repeat_k = BlockShape::Repeat_K0;
382 // loop order n->m->k
383 constexpr auto i_sub = i_access % repeat_sub;
384 constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
385 constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
386 constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
387
388 using AWarpTensor = typename WarpGemm::AWarpTensor;
389 using BWarpTensor = typename WarpGemm::BWarpTensor;
390 using CWarpTensor = typename WarpGemm::CWarpTensor;
391 using AWarpDstr = typename WarpGemm::AWarpDstr;
392 using BWarpDstr = typename WarpGemm::BWarpDstr;
393 using CWarpDstr = typename WarpGemm::CWarpDstr;
394
395 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
396 constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
397 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
398
399 constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
400 constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
401 constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
402
403 AWarpTensor w_a;
404 w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
405 merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
406 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
407
408 BWarpTensor w_b;
409 w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
410 merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
411 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
412
413 CWarpTensor w_c;
414 w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
415 merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
416 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
417
418 warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
419
420 t_c.set_y_sliced_thread_data(
421 merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
422 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
423 w_c.get_thread_buffer());
424 };
425 // clang-format on
426 _Pragma("clang diagnostic pop");
427
428 // this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
429 // be hide under mfma. In other words, issues of mfma is >= memory this is true if we
430 // pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
431 // paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
432 // preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
433 // mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
434 // mfma(that can reuse the B matrix) only affected by M repeat.
435 auto pipeline_gemm0 = [&]() {
436 constexpr index_t total_loops = issues_gemm0;
437 constexpr auto sr = Policy::template GetSequencer_0<Problem>();
438 static_assert(sr.size() == total_loops);
439
440 constexpr auto c_sld_a_0 = MAKE_SC();
441 constexpr auto c_gld_a_0 = MAKE_SC();
442 constexpr auto c_gld_b_0 = MAKE_SC();
443 // compute buffer 1
444 static_for<0, total_loops, 1>{}([&](auto i_issue) {
445 gemm_0(acc_0, as[I0], gs[I0], i_issue);
446 constexpr index_t slot = sr.at(i_issue);
447
448 if constexpr(slot & SLD_A)
449 sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
450 if constexpr(slot & GLD_A)
451 gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
452 if constexpr(slot & GLD_B)
453 gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
454 });
455 move_g();
456 move_a();
457 block_sync_load_raw(issues_a + issues_g);
459
460 constexpr auto c_sld_a_1 = MAKE_SC();
461 constexpr auto c_gld_a_1 = MAKE_SC();
462 constexpr auto c_gld_b_1 = MAKE_SC();
463
464 // compute buffer 1
465 static_for<0, total_loops, 1>{}([&](auto i_issue) {
466 gemm_0(acc_0, as[I1], gs[I1], i_issue);
467 constexpr index_t slot = sr.at(i_issue);
468
469 if constexpr(slot & SLD_A)
470 sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
471 if constexpr(slot & GLD_A)
472 gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
473 if constexpr(slot & GLD_B)
474 gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
475 });
476 move_g();
477 move_a();
478 block_sync_load_raw(issues_a + issues_g);
480 };
481
// Final two K blocks of gemm-0; no A global loads are issued here.
//   1) multiply as[I0] * gs[I0] into acc_0, while the GLD_B slots of
//      Policy::GetSequencer_0 issue G loads into gs[I1];
//   2) wait for those raw G loads (block_sync_load_raw(issues_g)) and read the buffer-1
//      A tile from LDS (a_sld_win1, i.e. smem_1) into as[I1];
//   3) multiply as[I1] * gs[I1] into acc_0; only the very last warp gemm gets PostNop = TRUE.
// NOTE(review): step 1 refills gs[I1] although step 3 then consumes gs[I1]; confirm the
// G window position makes the refilled block the one step 3 is meant to use.
// NOTE(review): sld_a goes through load_tile_raw (inline asm); no LDS wait is visible between
// it and the first use of as[I1] in step 3 — confirm none is required.
482 auto pipeline_gemm0_tail = [&]() {
483 constexpr index_t total_loops = issues_gemm0;
484 constexpr auto sr = Policy::template GetSequencer_0<Problem>();
485 static_assert(sr.size() == total_loops);
486
487 constexpr auto c_gld_b_0 = MAKE_SC();
488
489 // compute buffer 0 (step 1): consume as[I0]/gs[I0], G loads go to gs[I1]
490 static_for<0, total_loops, 1>{}([&](auto i_issue) {
491 gemm_0(acc_0, as[I0], gs[I0], i_issue);
492 constexpr index_t slot = sr.at(i_issue);
493
494 if constexpr(slot & GLD_B)
495 gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
496 });
497
498 block_sync_load_raw(issues_g);
499 sld_a(as[I1], a_sld_win1, NEG1);
500
501 // compute buffer 1 (step 3): consume as[I1]/gs[I1]
502 static_for<0, total_loops, 1>{}([&](auto i_issue) {
503 constexpr auto last_nop = [&]() {
504 if constexpr(i_issue == (total_loops - 1))
505 return TRUE;
506 else
507 return FALSE;
508 }();
509 gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
510 });
511 };
512
513 auto y = Policy::template MakeYBlockTile<Problem>();
514
// Bridge between gemm-0 and gemm-1. It does three things:
//   - casts the gemm-0 accumulator to YDataType;
//   - stores it through bridge_sst_win and reloads it through bridge_sld_win (Y tile
//     distribution) into `y`, which is the A operand passed to gemm_1;
//   - zeroes both gemm-1 accumulators.
// Both bridge windows are based at `smem`, the same address as smem_0.
515 auto pipeline_bridge = [&]() {
516 // cast to Y data
517 auto y_pre = cast_tile<YDataType>(acc_0);
518 store_tile(bridge_sst_win, y_pre);
519 clear_tile(acc_1s(I0));
// NOTE(review): nothing orders the LDS store above against the LDS load below (the
// wave_barrier is commented out). If bridge_sld_win reads elements written by other waves,
// a block-level LDS barrier is needed here — confirm against the Policy's bridge descriptors.
520 // wave_barrier();
521 load_tile(y, bridge_sld_win);
522 clear_tile(acc_1s(I1));
523 };
524
525 // note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
526 auto pipeline_gemm1 = [&]() {
527 constexpr index_t total_loops = issues_gemm1;
528 constexpr auto sr = Policy::template GetSequencer_1<Problem>();
529 static_assert(sr.size() == total_loops);
530
531 constexpr auto c_gld_b_0 = MAKE_SC();
532 constexpr auto c_gst_o_0 = MAKE_SC();
533 constexpr auto c_gld_b_1 = MAKE_SC();
534 constexpr auto c_gst_o_1 = MAKE_SC();
535
536 // compute buffer 0
537 static_for<0, total_loops, 1>{}([&](auto i_issue) {
538 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
539 constexpr index_t slot = sr.at(i_issue);
540 if constexpr(slot & GLD_B)
541 gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
542
543 if constexpr(slot & GST_O)
544 {
545 auto out = cast_tile<ODataType>(acc_1s[I0]);
546 atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
547 }
548 });
549 move_d();
550 // move_o();
551
552 // compute buffer 1
553 static_for<0, total_loops, 1>{}([&](auto i_issue) {
554 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
555 constexpr index_t slot = sr.at(i_issue);
556 if constexpr(slot & GLD_B)
557 gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
558
559 if constexpr(slot & GST_O)
560 {
561 auto out = cast_tile<ODataType>(acc_1s[I1]);
562 atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
563 }
564 });
565 move_d();
566 };
567
// First gemm-1 step (output tile 0). It accumulates acc_1s[I0] from y and ds[I0], while the
// GLD_B slots of Policy::GetSequencer_1 prefetch the next D tile into ds[I1]. Nothing is
// written to the output here. Afterwards move_d() is called.
// NOTE(review): ds[I0] must already hold the first D tile, but no gld_d into ds[I0] is
// visible before this lambda is called — confirm the prologue issues (and waits for) it.
568 auto pipeline_gemm1_head = [&]() {
569 constexpr index_t total_loops = issues_gemm1;
570 constexpr auto sr = Policy::template GetSequencer_1<Problem>();
571 static_assert(sr.size() == total_loops);
572
573 constexpr auto c_gld_b_0 = MAKE_SC();
574
575 // compute buffer 0: accumulate into acc_1s[I0], prefetch D into ds[I1]
576 static_for<0, total_loops, 1>{}([&](auto i_issue) {
577 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
578 constexpr index_t slot = sr.at(i_issue);
579 if constexpr(slot & GLD_B)
580 gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
581 });
582 move_d();
583 };
// Last gemm-1 step. It accumulates the final output tile into acc_1s[I1] from y and ds[I1].
// Meanwhile the GST_O slots of Policy::GetSequencer_1 atomically add the previous tile
// (acc_1s[I0]) to the output. Afterwards acc_1s[I1] itself is added with access id NEG1
// (presumably "all accesses" — confirm against update_tile_raw).
// NOTE(review): o_win is not advanced between these two atomic adds (no move_o is visible),
// so both tiles target the same output window — confirm this is intended.
584 auto pipeline_gemm1_tail = [&]() {
585 constexpr index_t total_loops = issues_gemm1;
586 constexpr auto sr = Policy::template GetSequencer_1<Problem>();
587 static_assert(sr.size() == total_loops);
588
589 constexpr auto c_gst_o_0 = MAKE_SC();
590
591 // compute buffer 1: accumulate into acc_1s[I1], drain acc_1s[I0]
592 static_for<0, total_loops, 1>{}([&](auto i_issue) {
593 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
594
595 constexpr index_t slot = sr.at(i_issue);
596 if constexpr(slot & GST_O)
597 {
598 auto out = cast_tile<ODataType>(acc_1s[I0]);
599 atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
600 }
601 });
602 {
// write out the last accumulator in one go
603 auto out = cast_tile<ODataType>(acc_1s[I1]);
604 atomic_add_o(out, NEG1);
605 }
606 };
607
608 // start of pipeline
609 // clang-format off
610 gld_a(a_sst_win0, NEG1, TRUE);
611 gld_g(gs[I0], NEG1, TRUE);
612 move_a();
613 move_g();
614 clear_tile(acc_0);
615
616 // preload for next round
617 gld_a(a_sst_win1, NEG1);
618 gld_g(gs[I1], NEG1);
619
620 // make sure a,g loaded
621 block_sync_load_raw(issues_a + issues_g);
623
624 // we manually unroll double buffer inside hot loop
625 const index_t iters_0 = (num_blocks_k0 - 2) / 2;
626 index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
627 while(i_0++ < iters_0)
628 {
629 pipeline_gemm0();
630 }
631 pipeline_gemm0_tail();
632
633 pipeline_bridge();
634
635 const index_t iters_1 = (num_blocks_n1 - 2) / 2;
636 index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
637 pipeline_gemm1_head();
638 while(i_1++ < iters_1)
639 {
640 pipeline_gemm1();
641 }
642 pipeline_gemm1_tail();
643 // clang-format on
644 }
645};
646
647} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
@ GST_O
Definition fused_moegemm_traits.hpp:48
@ GLD_B
Definition fused_moegemm_traits.hpp:45
@ SLD_A
Definition fused_moegemm_traits.hpp:42
@ GLD_A
Definition fused_moegemm_traits.hpp:44
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:820
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition load_tile.hpp:81
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition arch.hpp:121
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition static_counter.hpp:112
#define MAKE_SC()
Definition static_counter.hpp:104
Definition fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition fused_moegemm_pipeline_flatmm_ex.hpp:46
typename Problem::DScaleDataType DScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:91
typename Problem::GDataType GDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:30
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:77
static constexpr index_t kAlignmentG
Definition fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition fused_moegemm_pipeline_flatmm_ex.hpp:45
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize_A()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:72
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition fused_moegemm_pipeline_flatmm_ex.hpp:99
static constexpr value_type value
Definition tile/core/numeric/integral_constant.hpp:16
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192