device_cgemm.hpp Source File

device_cgemm.hpp Source File

Composable Kernel: device_cgemm.hpp Source File
device_cgemm.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <memory>

#include "device_base.hpp"

7namespace ck {
8namespace tensor_operation {
9namespace device {
10
11template <typename AElementwiseOperation,
12 typename BElementwiseOperation,
13 typename CElementwiseOperation>
15{
16 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
17 const void* p_a_imag,
18 const void* p_b_real,
19 const void* p_b_imag,
20 void* p_c_real,
21 void* p_c_imag,
22 void* p_workspace,
26 ck::index_t StrideA,
27 ck::index_t StrideB,
28 ck::index_t StrideC,
29 AElementwiseOperation a_element_op,
30 BElementwiseOperation b_element_op,
31 CElementwiseOperation c_element_op,
32 ck::index_t KBatch = 1) = 0;
33
34 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
35 virtual std::size_t GetWorkspaceSize(index_t MRaw,
36 index_t NRaw,
37 index_t KRaw,
38 index_t StrideA,
39 index_t StrideB,
40 index_t StrideC) const = 0;
41};
42
43template <typename AElementwiseOperation,
44 typename BElementwiseOperation,
45 typename CElementwiseOperation>
46using DeviceCGemmPtr = std::unique_ptr<
48
49} // namespace device
50} // namespace tensor_operation
51} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceCGemm< AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > DeviceCGemmPtr
Definition device_cgemm.hpp:46
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_cgemm.hpp:15
virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC) const =0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a_real, const void *p_a_imag, const void *p_b_real, const void *p_b_imag, void *p_c_real, void *p_c_imag, void *p_workspace, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch=1)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0