diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp
index f0d7dae706..f24b976b4c 100644
--- a/include/ck_tile/core/utility/transpose_vectors.hpp
+++ b/include/ck_tile/core/utility/transpose_vectors.hpp
@@ -26,136 +26,193 @@ struct transpose_vectors
     using VX = array<S, s_per_x>;
     using VY = array<S, s_per_y>;
 
-    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
-                                   thread_buffer<VY, NY>& vy_tuple)
+    struct generic_tag
     {
+    };
+    struct bytesize2_2x2_tag
+    {
+    };
+    struct bytesize1_4x4_tag
+    {
+    };
+    struct bytesize1_2x2_tag
+    {
+    };
+
+    CK_TILE_DEVICE static constexpr void
+    apply_impl(const thread_buffer<VX, NX>& vx_tuple, thread_buffer<VY, NY>& vy_tuple, generic_tag)
+    {
+        static_for<0, NY, 1>{}([&](auto iy) {
+            static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
+        });
+    }
+
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize2_2x2_tag)
+    {
+        static_assert(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
+
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
+        using S2 = array<S, 2>;
+        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // 2 16bitx2 data from vx_tuple to be transposed
+                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
+                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
+
+                // transpose 2x2 16bit
+                // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+                //        --  --  --  --     --  --  --  --      -   -   -   -
+                // index   7   6   5   4      3   2   1   0     33  77  44  88
+                // index is reversed because of little endianness (least significant bits first)
+                const S2 y_s2_0 = bit_cast<S2>(
+                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
+                                          bit_cast<uint32_t>(x_s2_1),
+                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[1, 0, 5, 4] = (C1.D1.C0.D0)
+                                          0x01'00'05'04));
+                const S2 y_s2_1 = bit_cast<S2>(
+                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
+                                          bit_cast<uint32_t>(x_s2_1),
+                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[3, 2, 7, 6] = (A1.B1.A0.B0)
+                                          0x03'02'07'06));
+
+                // write transposed 2x2 result:
+                // write (C1.D1.C0.D0)
+                vy_tuple(iy).set_as(ix / I2, y_s2_0);
+                // write (A1.B1.A0.B0)
+                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
+            });
+        });
+    }
+
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize1_4x4_tag)
+    {
+        static_assert(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0, "wrong!");
+
         constexpr auto I1 = number<1>{};
         constexpr auto I2 = number<2>{};
         constexpr auto I3 = number<3>{};
         constexpr auto I4 = number<4>{};
+        using S4 = array<S, 4>;
+        // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 4>{}([&](auto iy) {
+            static_for<0, NX, 4>{}([&](auto ix) {
+                // read A0.B0.C0.D0
+                const S4 x_s4_0 = vx_tuple[ix].template get_as<S4>(iy / I4);
+                // read A1.B1.C1.D1
+                const S4 x_s4_1 = vx_tuple[ix + I1].template get_as<S4>(iy / I4);
+                // read A2.B2.C2.D2
+                const S4 x_s4_2 = vx_tuple[ix + I2].template get_as<S4>(iy / I4);
+                // read A3.B3.C3.D3
+                const S4 x_s4_3 = vx_tuple[ix + I3].template get_as<S4>(iy / I4);
 
-        if constexpr(sizeof(S) == 4)
-        {
-            static_for<0, NY, 1>{}([&](auto iy) {
-                static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
+                // (A1.B1.C1.D1.A0.B0.C0.D0)[5, 1, 4, 0] = (C1.C0.D1.D0)
+                uint32_t t_s4_0 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x05'01'04'00);
+                // (A3.B3.C3.D3.A2.B2.C2.D2)[5, 1, 4, 0] = (C3.C2.D3.D2)
+                uint32_t t_s4_1 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x05'01'04'00);
+                // (C3.C2.D3.D2.C1.C0.D1.D0)[5, 4, 1, 0] = (D3.D2.D1.D0)
+                const S4 y_s4_0 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
+                // (C3.C2.D3.D2.C1.C0.D1.D0)[7, 6, 3, 2] = (C3.C2.C1.C0)
+                const S4 y_s4_1 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
+                // (A1.B1.C1.D1.A0.B0.C0.D0)[7, 3, 6, 2] = (A1.A0.B1.B0)
+                t_s4_0 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x07'03'06'02);
+                // (A3.B3.C3.D3.A2.B2.C2.D2)[7, 3, 6, 2] = (A3.A2.B3.B2)
+                t_s4_1 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x07'03'06'02);
+                // (A3.A2.B3.B2.A1.A0.B1.B0)[5, 4, 1, 0] = (B3.B2.B1.B0)
+                const S4 y_s4_2 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
+                // (A3.A2.B3.B2.A1.A0.B1.B0)[7, 6, 3, 2] = (A3.A2.A1.A0)
+                const S4 y_s4_3 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
+
+                // write transposed 4x4 result:
+                // write (D3.D2.D1.D0)
+                vy_tuple(iy).set_as(ix / I4, y_s4_0);
+                // write (C3.C2.C1.C0)
+                vy_tuple(iy + I1).set_as(ix / I4, y_s4_1);
+                // write (B3.B2.B1.B0)
+                vy_tuple(iy + I2).set_as(ix / I4, y_s4_2);
+                // write (A3.A2.A1.A0)
+                vy_tuple(iy + I3).set_as(ix / I4, y_s4_3);
             });
-        }
-        else if constexpr(sizeof(S) == 2)
-        {
-            static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
+        });
+    }
 
-            using S2 = array<S, 2>; // typename array<S, 2>::type;
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize1_2x2_tag)
+    {
+        static_assert(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
 
-            // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
-            static_for<0, NY, 2>{}([&](auto iy) {
-                static_for<0, NX, 2>{}([&](auto ix) {
-                    // 2 16bitx2 data from vx_tuple to be transposed
-                    const int32_t x_s2_0 =
-                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
-                    const int32_t x_s2_1 =
-                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
+        using S2 = array<S, 2>;
+        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // read A0.B0
+                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
+                // read A1.B1
+                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
 
-                    constexpr int32_t m0 = 0x05040100;
-                    constexpr int32_t m1 = 0x07060302;
+                // v_perm_b32: pick 4 bytes from 8 bytes in (input0.input1) using the mask
+                const S2 y_s2_0 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
+                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 0, 4] = (00.00.B1.B0)
+                    0x0C'0C'00'04)));
 
-                    // transpose 2x2 16bit
-                    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
-                    //        --  --  --  --     --  --  --  --      -   -   -   -
-                    // index   7   6   5   4      3   2   1   0     33  77  44  88
-                    // index is reversed because of little endianness (least significant bits first)
-                    const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
-                    const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
+                const S2 y_s2_1 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
+                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 1, 5] = (00.00.A1.A0)
+                    0x0C'0C'01'05)));
 
-                    // 2 16bitx2 data after transposed
-                    vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
-                    vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
-                });
+                // write transposed 2x2 result:
+                // write (B1.B0)
+                vy_tuple(iy).set_as(ix / I2, y_s2_0);
+                // write (A1.A0)
+                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
             });
-        }
-        else if constexpr(sizeof(S) == 1)
+        });
+    }
+
+    CK_TILE_DEVICE static constexpr auto tag_dispatch()
+    {
+        if constexpr(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0)
         {
-            static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!");
-
-            using S4 = array<S, 4>; // typename array<S, 4>::type;
-            using S2 = array<S, 2>; // typename array<S, 2>::type;
-
-            if constexpr(NX % 4 == 0 && NY % 4 == 0)
-            {
-                // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
-                static_for<0, NY, 4>{}([&](auto iy) {
-                    static_for<0, NX, 4>{}([&](auto ix) {
-                        // 4 int8x4 data from vx_tuple
-                        const int32_t x_s4_0 =
-                            bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_1 =
-                            bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_2 =
-                            bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_3 =
-                            bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
-
-                        // transpose
-                        int32_t t_s4_0, t_s4_1;
-                        int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
-
-                        constexpr int32_t m0 = 0x05010400;
-                        constexpr int32_t m1 = 0x05040100;
-                        constexpr int32_t m2 = 0x07060302;
-                        constexpr int32_t m3 = 0x07030602;
-
-                        // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
-                        // 0x33774488
-                        //        --  --  --  --     --  --  --  --      -   -   -   -
-                        // index   7   6   5   4      3   2   1   0     33  77  44  88
-                        // index is reversed because of little endianness (least significant bits
-                        // first)
-                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
-                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
-                        y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                        y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
-                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
-                        y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                        y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-
-                        // 4 int8x4 data from vy_tuple
-                        vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
-                        vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
-                        vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
-                        vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
-                    });
-                });
-            }
-            else if constexpr(NX % 2 == 0 && NY % 2 == 0)
-            {
-                static_for<0, NY, 2>{}([&](auto ix) {
-                    static_for<0, NX, 2>{}([&](auto iy) {
-                        const int16_t x_s2_0 =
-                            bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
-                        const int16_t x_s2_1 =
-                            bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
-                        constexpr int32_t m0 = 0x05040100;
-                        constexpr int32_t m1 = 0x07060302;
-
-                        const int32_t x0_32 = static_cast<int32_t>(x_s2_0 & 0xFFFF);
-                        const int32_t x1_32 = static_cast<int32_t>(x_s2_1 & 0xFFFF);
-
-                        const int32_t y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0);
-                        const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1);
-
-                        vy_tuple(iy).template get_as<S2>()[ix / I2] =
-                            bit_cast<S2>(static_cast<int16_t>(y_s2_0 & 0xFFFF));
-                        vy_tuple(iy + I1).template get_as<S2>()[ix / I2] =
-                            bit_cast<S2>(static_cast<int16_t>(y_s2_1 & 0xFFFF));
-                    });
-                });
-            }
+            return bytesize2_2x2_tag{};
+        }
+        else if constexpr(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0)
+        {
+            return bytesize1_4x4_tag{};
+        }
+        else if constexpr(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0)
+        {
+            return bytesize1_2x2_tag{};
         }
         else
         {
-            static_assert(false, "not implemented");
+            return generic_tag{};
        }
     }
+
+    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
+                                   thread_buffer<VY, NY>& vy_tuple) const
+    {
+        apply_impl(vx_tuple, vy_tuple, tag_dispatch());
+    }
 };
 
 } // namespace ck_tile
diff --git a/test/ck_tile/batched_transpose/test_batched_transpose.cpp b/test/ck_tile/batched_transpose/test_batched_transpose.cpp
index 8812397946..71a133a4b6 100644
--- a/test/ck_tile/batched_transpose/test_batched_transpose.cpp
+++ b/test/ck_tile/batched_transpose/test_batched_transpose.cpp
@@ -306,6 +306,12 @@ class CaseHalfPadRectTile2LoadTranspose
 {
 };
 
+class CaseBytePadRectTile
+    : public TestCkTileBatchedTranspose<
+          PipelineConfig</* FIXME: template arguments lost in text extraction; restore the int8 pad-rect-tile config (cf. the Case*PadRectTile sibling cases) before applying this patch */>>
+{
+};
+
 TEST_P(CaseHalf, TestCorrectness) { this->Run(GetParam()); }
 TEST_P(CaseByte, TestCorrectness) { this->Run(GetParam()); }
 TEST_P(CaseWord, TestCorrectness) { this->Run(GetParam()); }
@@ -321,6 +327,7 @@ TEST_P(CaseHalfPadRectTile1, TestCorrectness) { this->Run(GetParam()); }
 TEST_P(CaseHalfPadRectTile1LoadTranspose, TestCorrectness) { this->Run(GetParam()); }
 TEST_P(CaseHalfPadRectTile2, TestCorrectness) { this->Run(GetParam()); }
 TEST_P(CaseHalfPadRectTile2LoadTranspose, TestCorrectness) { this->Run(GetParam()); }
+TEST_P(CaseBytePadRectTile, TestCorrectness) { this->Run(GetParam()); }
 
 // clang-format off
 INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalf, kTestingValues);
@@ -338,5 +345,6 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1, kTestingValues);
 INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1, kTestingValues);
 INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1LoadTranspose, kTestingValues);
 INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2, kTestingValues);
 INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2LoadTranspose, kTestingValues);
+INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseBytePadRectTile, kTestingValues);
 // clang-format on