18template <
typename OldLayout>
34 return {0, 1, 4, 2, 3};
40 return {0, 1, 5, 2, 3, 4};
57 return {1, 0, 2, 3, 4};
62 return {1, 0, 2, 3, 4, 5};
68 return {0, 1, 2, 3, 4};
74 return {0, 1, 2, 3, 4, 5};
86 return {0, 1, 4, 2, 3};
92 return {0, 1, 5, 2, 3, 4};
104 return {3, 0, 4, 1, 2};
110 return {4, 0, 5, 1, 2, 3};
114 printf(
"%s\n", __func__);
115 throw std::runtime_error(
"wrong! unsupported layout");
123template <
typename InLayout>
127 std::vector<std::size_t> physical_lengths;
138 throw std::runtime_error(
"wrong! G != 1");
141 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
142 static_cast<std::size_t
>(param.
N_),
143 static_cast<std::size_t
>(param.
C_)};
145 physical_lengths.insert(physical_lengths.begin() + 2,
154 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
N_),
155 static_cast<std::size_t
>(param.
G_),
156 static_cast<std::size_t
>(param.
C_)};
158 physical_lengths.insert(physical_lengths.end(),
166 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
167 static_cast<std::size_t
>(param.
N_),
168 static_cast<std::size_t
>(param.
C_)};
170 physical_lengths.insert(physical_lengths.end(),
178 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
179 static_cast<std::size_t
>(param.
N_),
180 static_cast<std::size_t
>(param.
C_)};
182 physical_lengths.insert(physical_lengths.begin() + 2,
190 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
N_),
191 static_cast<std::size_t
>(param.
G_),
192 static_cast<std::size_t
>(param.
C_)};
194 physical_lengths.insert(physical_lengths.begin() + 1,
200 printf(
"%s\n", __func__);
201 printf(
"%s\n", InLayout::name);
202 throw std::runtime_error(
"wrong! unsupported layout");
215template <
typename WeiLayout>
219 std::vector<std::size_t> physical_lengths;
230 throw std::runtime_error(
"wrong! G != 1");
233 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
234 static_cast<std::size_t
>(param.
K_),
235 static_cast<std::size_t
>(param.
C_)};
237 physical_lengths.insert(physical_lengths.begin() + 2,
248 throw std::runtime_error(
"wrong! G != 1");
251 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
K_),
252 static_cast<std::size_t
>(param.
C_)};
254 physical_lengths.insert(physical_lengths.end(),
262 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
263 static_cast<std::size_t
>(param.
K_),
264 static_cast<std::size_t
>(param.
C_)};
266 physical_lengths.insert(physical_lengths.end(),
274 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
275 static_cast<std::size_t
>(param.
K_),
276 static_cast<std::size_t
>(param.
C_)};
278 physical_lengths.insert(physical_lengths.begin() + 2,
286 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
K_),
287 static_cast<std::size_t
>(param.
G_),
288 static_cast<std::size_t
>(param.
C_)};
290 physical_lengths.insert(physical_lengths.begin() + 1,
296 printf(
"%s\n", __func__);
297 printf(
"%s\n", WeiLayout::name);
298 throw std::runtime_error(
"wrong! unsupported layout");
310template <
typename OutLayout>
314 std::vector<std::size_t> physical_lengths;
325 throw std::runtime_error(
"wrong! G != 1");
328 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
329 static_cast<std::size_t
>(param.
N_),
330 static_cast<std::size_t
>(param.
K_)};
332 physical_lengths.insert(physical_lengths.begin() + 2,
341 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
342 static_cast<std::size_t
>(param.
N_),
343 static_cast<std::size_t
>(param.
K_)};
345 physical_lengths.insert(physical_lengths.end(),
354 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
N_),
355 static_cast<std::size_t
>(param.
G_),
356 static_cast<std::size_t
>(param.
K_)};
358 physical_lengths.insert(physical_lengths.end(),
366 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
G_),
367 static_cast<std::size_t
>(param.
N_),
368 static_cast<std::size_t
>(param.
K_)};
370 physical_lengths.insert(physical_lengths.begin() + 2,
378 physical_lengths = std::vector<std::size_t>{
static_cast<std::size_t
>(param.
N_),
379 static_cast<std::size_t
>(param.
G_),
380 static_cast<std::size_t
>(param.
K_)};
382 physical_lengths.insert(physical_lengths.begin() + 1,
388 printf(
"%s\n", __func__);
389 printf(
"%s\n", OutLayout::name);
390 throw std::runtime_error(
"wrong! unsupported layout");
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor &a, const New2Old &new2old, const NewLayout &new_layout=NewLayout())
Definition library/utility/host_tensor.hpp:599
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:16
std::vector< std::size_t > get_layout_transpose_gnchw_to_old()
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:19
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:14
HostTensorDescriptor make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam &param)
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:217
HostTensorDescriptor make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam &param)
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:312
HostTensorDescriptor make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam &param)
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:125
Definition library/utility/check_err.hpp:24
constexpr bool is_same_v
Definition type.hpp:283
A descriptor class for host tensors that manages tensor dimensions, strides, and layout.
Definition library/utility/host_tensor.hpp:171
Definition tensor_operation/gpu/device/tensor_layout.hpp:45
Definition library/utility/convolution_parameter.hpp:20
ck::long_index_t C_
Definition library/utility/convolution_parameter.hpp:50
ck::long_index_t num_dim_spatial_
Definition library/utility/convolution_parameter.hpp:46
std::vector< ck::long_index_t > input_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:53
std::vector< ck::long_index_t > output_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:54
ck::long_index_t N_
Definition library/utility/convolution_parameter.hpp:48
ck::long_index_t G_
Definition library/utility/convolution_parameter.hpp:47
ck::long_index_t K_
Definition library/utility/convolution_parameter.hpp:49
std::vector< ck::long_index_t > filter_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:52