device_gemm_v2.hpp Source File

device_gemm_v2.hpp Source File

Composable Kernel: device_gemm_v2.hpp Source File
device_gemm_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9namespace tensor_operation {
10namespace device {
11
// GEMM device-operation interface (first of four in this listing). Every member function
// visible here is pure virtual (= 0), so a concrete implementation must provide them.
// NOTE(review): this rendered listing drops original source line 21 (between the template
// parameter list and the opening brace -- presumably the struct declaration; its name and
// any base class are not visible here), lines 27-29 of the MakeArgumentPointer parameter
// list, and line 42. The MakeArgumentPointer tooltip text at the end of the page shows the
// full parameter list, which has ck::index_t M, N, K between p_c and StrideA.
12template <typename ALayout,
13 typename BLayout,
14 typename CLayout,
15 typename ADataType,
16 typename BDataType,
17 typename CDataType,
18 typename AElementwiseOperation,
19 typename BElementwiseOperation,
20 typename CElementwiseOperation>
22{
// Packs the operand pointers, strides, KSplit and the three elementwise operations into an
// argument object, returned as std::unique_ptr<BaseArgument>. A and B are passed as
// const void* (not written through these pointers); C is void*, presumably the output.
// NOTE(review): presumably the buffers hold ADataType / BDataType / CDataType elements and
// KSplit is the split-K factor -- confirm against an implementation.
23 virtual std::unique_ptr<BaseArgument>
24 MakeArgumentPointer(const void* p_a,
25 const void* p_b,
26 void* p_c,
30 ck::index_t StrideA,
31 ck::index_t StrideB,
32 ck::index_t StrideC,
33 ck::index_t KSplit,
34 AElementwiseOperation a_element_op,
35 BElementwiseOperation b_element_op,
36 CElementwiseOperation c_element_op) = 0;
37
// Returns the invoker object as std::unique_ptr<BaseInvoker>.
38 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
39
// NOTE(review): presumably these report whether the implementation expects a permuted
// A / B operand layout -- confirm against implementations.
40 virtual bool GetPermuteA() = 0;
41 virtual bool GetPermuteB() = 0;
43};
44
// GEMM device-operation interface with NumDTensor additional "D" input tensors (second of
// four in this listing). Every member function visible here is pure virtual (= 0).
// NOTE(review): this rendered listing drops original source line 56 (presumably the struct
// declaration; its name and any base class are not visible here) and lines 65-67 of the
// MakeArgumentPointer parameter list. The MakeArgumentPointer tooltip text at the end of
// the page shows the full parameter list, which has ck::index_t M, N, K between p_c and
// StrideA.
45template <typename ALayout,
46 typename BLayout,
47 typename DsLayout,
48 typename CLayout,
49 typename ADataType,
50 typename BDataType,
51 typename DsDataType,
52 typename CDataType,
53 typename AElementwiseOperation,
54 typename BElementwiseOperation,
55 typename CElementwiseOperation>
57{
// Number of D tensors, taken from DsDataType::Size(). Because it initialises a static
// constexpr member, DsDataType must provide a Size() usable in a constant expression.
58 static constexpr index_t NumDTensor = DsDataType::Size();
59
// Like the plain GEMM interface, but it also takes the D tensor pointers (p_ds) and their
// strides (DsStrides) as fixed-size std::arrays of length NumDTensor. The D pointers are
// const void*, so they are not written through. The result is returned as
// std::unique_ptr<BaseArgument>.
// NOTE(review): presumably KSplit is the split-K factor -- confirm against an implementation.
60 virtual std::unique_ptr<BaseArgument>
61 MakeArgumentPointer(const void* p_a,
62 const void* p_b,
63 std::array<const void*, NumDTensor> p_ds,
64 void* p_c,
68 ck::index_t StrideA,
69 ck::index_t StrideB,
70 std::array<ck::index_t, NumDTensor> DsStrides,
71 ck::index_t StrideC,
72 ck::index_t KSplit,
73 AElementwiseOperation a_element_op,
74 BElementwiseOperation b_element_op,
75 CElementwiseOperation c_element_op) = 0;
76
// Returns the invoker object as std::unique_ptr<BaseInvoker>.
77 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
78};
79
// GEMM device-operation interface with an extra B-scale input (third of four in this
// listing). Every member function visible here is pure virtual (= 0).
// NOTE(review): this rendered listing drops original source line 92 (presumably the struct
// declaration; its name and any base class are not visible here), lines 98-99 of the
// MakeArgumentPointer parameter list, and line 114. The MakeArgumentPointer tooltip text at
// the end of the page shows the full parameter list, which has ck::index_t M, N before K.
// NOTE(review): presumably BScaleType is the element type of per-block scale factors for B,
// applied at a ScaleBlockN x ScaleBlockK granularity -- confirm against an implementation.
80template <typename ALayout,
81 typename BLayout,
82 typename CLayout,
83 typename ADataType,
84 typename BDataType,
85 typename BScaleType,
86 typename CDataType,
87 index_t ScaleBlockN,
88 index_t ScaleBlockK,
89 typename AElementwiseOperation,
90 typename BElementwiseOperation,
91 typename CElementwiseOperation>
93{
// Like the plain GEMM interface, but it also takes the B-scale buffer (p_b_scale, passed as
// const void*) and its stride (StrideScaleB). The result is returned as
// std::unique_ptr<BaseArgument>.
// NOTE(review): presumably KSplit is the split-K factor -- confirm against an implementation.
94 virtual std::unique_ptr<BaseArgument>
95 MakeArgumentPointer(const void* p_a,
96 const void* p_b,
97 void* p_c,
100 ck::index_t K,
101 ck::index_t StrideA,
102 ck::index_t StrideB,
103 ck::index_t StrideC,
104 ck::index_t StrideScaleB,
105 const void* p_b_scale,
106 ck::index_t KSplit,
107 AElementwiseOperation a_element_op,
108 BElementwiseOperation b_element_op,
109 CElementwiseOperation c_element_op) = 0;
110
// Returns the invoker object as std::unique_ptr<BaseInvoker>.
111 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
112
// NOTE(review): presumably reports whether the implementation expects a permuted B operand
// layout -- confirm against implementations. This interface has no GetPermuteA.
113 virtual bool GetPermuteB() = 0;
115};
116
// GEMM device-operation interface that also exposes a pre-shuffle query (fourth of four in
// this listing). Every member function visible here is pure virtual (= 0).
// NOTE(review): this rendered listing drops original source line 126 (presumably the struct
// declaration; its name and any base class are not visible here) and line 147. Unlike the
// other listings, all MakeArgumentPointer parameters (including M, N, K) are present here.
117template <typename ALayout,
118 typename BLayout,
119 typename CLayout,
120 typename ADataType,
121 typename BDataType,
122 typename CDataType,
123 typename AElementwiseOperation,
124 typename BElementwiseOperation,
125 typename CElementwiseOperation>
127{
// Packs the operand pointers, problem sizes M/N/K, strides, KSplit and the three elementwise
// operations into an argument object, returned as std::unique_ptr<BaseArgument>.
// NOTE(review): presumably KSplit is the split-K factor -- confirm against an implementation.
128 virtual std::unique_ptr<BaseArgument>
129 MakeArgumentPointer(const void* p_a,
130 const void* p_b,
131 void* p_c,
132 ck::index_t M,
133 ck::index_t N,
134 ck::index_t K,
135 ck::index_t StrideA,
136 ck::index_t StrideB,
137 ck::index_t StrideC,
138 ck::index_t KSplit,
139 AElementwiseOperation a_element_op,
140 BElementwiseOperation b_element_op,
141 CElementwiseOperation c_element_op) = 0;
142
// Returns the invoker object as std::unique_ptr<BaseInvoker>.
143 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
144
// NOTE(review): presumably these report whether the implementation expects a permuted
// A / B operand layout -- confirm against implementations.
145 virtual bool GetPermuteA() = 0;
146 virtual bool GetPermuteB() = 0;
// Returns an int describing the implementation's pre-shuffle configuration.
// NOTE(review): the meaning of the returned value is not visible here -- confirm against an
// implementation.
148 virtual int GetPreShuffleParameters() = 0;
149};
150
151} // namespace device
152} // namespace tensor_operation
153} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition device_gemm_v2.hpp:93
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t StrideScaleB, const void *p_b_scale, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition device_gemm_v2.hpp:22
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition device_gemm_v2.hpp:57
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition device_gemm_v2.hpp:58
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > DsStrides, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0