28 vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
29 vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
31 vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
32 vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
34 y0 = vy0.template AsType<half2_t>()[I0];
35 y1 = vy1.template AsType<half2_t>()[I0];
37 constexpr int32_t m0 = 0x05040100;
38 constexpr int32_t m1 = 0x07060302;
49template <index_t NX, index_t NY>
66 static_assert((NX % 2 == 0 && NY % 2 == 0),
"wrong!");
72 const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
73 const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
76 auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
77 auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
98 constexpr int32_t m0 = 0x05010400;
99 constexpr int32_t m1 = 0x05040100;
100 constexpr int32_t m2 = 0x07060302;
101 constexpr int32_t m3 = 0x07030602;
122template <index_t NX, index_t NY>
141 static_assert((NX % 4 == 0 && NY % 4 == 0),
"wrong!");
147 const auto& x_s4_0 = vx_tuple[ix].template AsType<int8x4_t>()[iy / I4];
148 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<int8x4_t>()[iy / I4];
149 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<int8x4_t>()[iy / I4];
150 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<int8x4_t>()[iy / I4];
153 auto& y_s4_0 = vy_tuple(iy).template AsType<int8x4_t>()(ix / I4);
154 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<int8x4_t>()(ix / I4);
155 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<int8x4_t>()(ix / I4);
156 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<int8x4_t>()(ix / I4);
177 constexpr int32_t m0 = 0x05010400;
178 constexpr int32_t m1 = 0x05040100;
179 constexpr int32_t m2 = 0x07060302;
180 constexpr int32_t m3 = 0x07030602;
201template <index_t NX, index_t NY>
220 static_assert((NX % 4 == 0 && NY % 4 == 0),
"wrong!");
226 const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
227 const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
228 const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
229 const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
232 auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
233 auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
234 auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
235 auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
238 transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ void transpose_f8_4x4(const f8x4_t &x0, const f8x4_t &x1, const f8x4_t &x2, const f8x4_t &x3, f8x4_t &y0, f8x4_t &y1, f8x4_t &y2, f8x4_t &y3)
Definition utility/transpose_vectors.hpp:166
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__device__ void transpose_int8_4x4(const int8x4_t &x0, const int8x4_t &x1, const int8x4_t &x2, const int8x4_t &x3, int8x4_t &y0, int8x4_t &y1, int8x4_t &y2, int8x4_t &y3)
Definition utility/transpose_vectors.hpp:87
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void transpose_fp16_2x2(const half2_t &x0, const half2_t &x1, half2_t &y0, half2_t &y1)
Definition utility/transpose_vectors.hpp:19
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition functional2.hpp:33
vector_type< f8_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:209
f8_t S
Definition utility/transpose_vectors.hpp:208
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:212
vector_type< f8_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:210
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:205
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:206
vector_type< half_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:57
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:53
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:54
half_t S
Definition utility/transpose_vectors.hpp:56
vector_type< half_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:58
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:60
__device__ void operator()(const StaticallyIndexedArray< const VX &, NX > &vx_tuple, StaticallyIndexedArray< VY &, NY > &vy_tuple)
Definition utility/transpose_vectors.hpp:133
vector_type< int8_t, s_per_y > VY
Definition utility/transpose_vectors.hpp:131
static constexpr index_t s_per_x
Definition utility/transpose_vectors.hpp:126
int8_t S
Definition utility/transpose_vectors.hpp:129
static constexpr index_t s_per_y
Definition utility/transpose_vectors.hpp:127
vector_type< int8_t, s_per_x > VX
Definition utility/transpose_vectors.hpp:130
Definition utility/transpose_vectors.hpp:16
Definition dtype_vector.hpp:10