// Review notes (commentary only; the patch itself starts at the first "diff --git" line).
//
// What the patch does, as far as this view shows:
//   1. threadwise_tensor_slice_transfer_v3r1.hpp: the `if constexpr` guarding the sub-dword
//      transpose keeps its existing test (both ScalarPerVector values divisible by 2) and gains
//      a second alternative that requires both values to be divisible by 4.
//   2. transpose_vectors.hpp: fixes the "ammount" typo, adds transpose_int8_4x4() and adds a
//      transpose_vectors specialization with `using S = int8_t` that applies it to every 4x4 tile.
//
// NOTE(review): angle-bracket template argument lists look stripped in this copy of the patch
// (`is_same>::value`, `template` with no parameter list, `vector_type;`, `bit_cast(x)`,
// `AsType()`, `StaticallyIndexedArray&`). Which types each is_same tests cannot be read from
// here - presumably the new alternative is the int8 case; confirm against the original patch.
// NOTE(review): transpose_int8_4x4 is a non-template function defined in a header without
// `inline` - confirm this cannot produce duplicate definitions in the builds that include it.
// NOTE(review): the v_perm_b32 mnemonic is emitted through inline asm, so this code only
// assembles for targets that provide that instruction - TODO confirm the supported targets.
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
index dbe057e20d..4cd41ddb30 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
@@ -277,9 +277,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
         // TODO make this logic more generic for more sub-dword datatype
         if constexpr(SrcVectorDim != DstVectorDim &&
-                     is_same>::value &&
-                     is_same>::value &&
-                     SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
+                     ((is_same>::value &&
+                       is_same>::value &&
+                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
+                      (is_same>::value &&
+                       is_same>::value &&
+                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
             // each transpose does
             // DstScalarPerVector # of src vectors in src_thread_scratch_
diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp
index 866241a947..31f9c02c74 100644
--- a/include/ck/utility/transpose_vectors.hpp
+++ b/include/ck/utility/transpose_vectors.hpp
@@ -49,7 +49,7 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
 template
 struct transpose_vectors
 {
-    // we got [NY * NX] ammount of S data to be transposed
+    // we got [NY * NX] amount of S data to be transposed
     static constexpr index_t s_per_x = NY;
     static constexpr index_t s_per_y = NX;
 
@@ -83,5 +83,86 @@ struct transpose_vectors
     }
 };
 
+// transpose a 4x4 tile of int8: byte j of input row x_i becomes byte i of output row y_j
+__device__ void transpose_int8_4x4(const int8x4_t& x0,
+                                   const int8x4_t& x1,
+                                   const int8x4_t& x2,
+                                   const int8x4_t& x3,
+                                   int8x4_t& y0,
+                                   int8x4_t& y1,
+                                   int8x4_t& y2,
+                                   int8x4_t& y3)
+{
+    int32_t t0, t1;
+    int32_t z0, z1, z2, z3;
+    constexpr int32_t m0 = 0x05010400; // interleave bytes 0-1 of both operands
+    constexpr int32_t m1 = 0x05040100; // concatenate the low 16-bit halves
+    constexpr int32_t m2 = 0x07060302; // concatenate the high 16-bit halves
+    constexpr int32_t m3 = 0x07030602; // interleave bytes 2-3 of both operands
+
+    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+    //                   -- -- -- --     -- -- -- --      -  -  -  -
+    //             index  7  6  5  4      3  2  1  0     33 77 44 88
+    // index is reversed because of little endianness (least significant bits first)
+    // clang-format off
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m0));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m0));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m3));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m3));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1));
+    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2));
+    // clang-format on
+
+    y0 = bit_cast(z0);
+    y1 = bit_cast(z1);
+    y2 = bit_cast(z2);
+    y3 = bit_cast(z3);
+}
+
+template
+struct transpose_vectors
+{
+    // we got [NY * NX] amount of S data to be transposed
+    static constexpr index_t s_per_x = NY;
+    static constexpr index_t s_per_y = NX;
+
+    using S  = int8_t;
+    using VX = vector_type;
+    using VY = vector_type;
+
+    __device__ void operator()(const StaticallyIndexedArray& vx_tuple,
+                               StaticallyIndexedArray& vy_tuple)
+    {
+        static constexpr auto I1 = Number<1>{};
+        static constexpr auto I2 = Number<2>{};
+        static constexpr auto I3 = Number<3>{};
+        static constexpr auto I4 = Number<4>{};
+
+        static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); // both loops below step by 4
+
+        // visit every 4x4 tile (iy, ix step by 4) and transpose it from vx_tuple into vy_tuple
+        static_for<0, NY, 4>{}([&](auto iy) {
+            static_for<0, NX, 4>{}([&](auto ix) {
+                // tile inputs: group iy / 4 (4 int8 each) of vx_tuple[ix] .. vx_tuple[ix + 3]
+                const auto& x_s4_0 = vx_tuple[ix].template AsType()[iy / I4];
+                const auto& x_s4_1 = vx_tuple[ix + I1].template AsType()[iy / I4];
+                const auto& x_s4_2 = vx_tuple[ix + I2].template AsType()[iy / I4];
+                const auto& x_s4_3 = vx_tuple[ix + I3].template AsType()[iy / I4];
+
+                // tile outputs: group ix / 4 (4 int8 each) of vy_tuple(iy) .. vy_tuple(iy + 3)
+                auto& y_s4_0 = vy_tuple(iy).template AsType()(ix / I4);
+                auto& y_s4_1 = vy_tuple(iy + I1).template AsType()(ix / I4);
+                auto& y_s4_2 = vy_tuple(iy + I2).template AsType()(ix / I4);
+                auto& y_s4_3 = vy_tuple(iy + I3).template AsType()(ix / I4);
+
+                // transpose
+                transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
+            });
+        });
+    }
+};
+
 } // namespace ck
 #endif