[CK_TILE] Fix transpose_vectors for 2x2 8-bit tiles (#3042)

Fix the transpose_vectors logic for 2x2 8-bit tiles.

    Add a test that goes through this code path.

    Factor the constexpr'd cases out into smaller tag-dispatched functions.

    Add inline docs about the data movement.

    Impact: GEMMs with 8-bit non-rcr inputs on gfx942.
Max Podkorytov
2025-10-20 13:40:44 -07:00
committed by GitHub
parent 9f77061094
commit 2570462ecf
2 changed files with 176 additions and 111 deletions
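
The inline docs in the diff below lean on the byte-selector semantics of v_perm_b32. Here is a minimal host-side sketch of that instruction (not part of this commit; perm_b32 is a hypothetical reference helper, and only the selector values the diff actually uses are modeled: 0-7 pick a data byte, 0x0C clears the byte):

    #include <cassert>
    #include <cstdint>

    // Reference model of v_perm_b32: src0 forms bytes 4..7 of an 8-byte pool,
    // src1 forms bytes 0..3; each selector byte picks one pool byte for the
    // corresponding result byte. Selector 0x0C yields a cleared (0x00) byte;
    // the other constant selectors (0x08..0x0B, 0x0D..0xFF) are not modeled.
    uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t selector)
    {
        const uint64_t pool = (static_cast<uint64_t>(src0) << 32) | src1;
        uint32_t out        = 0;
        for(int i = 0; i < 4; ++i)
        {
            const uint32_t sel = (selector >> (8 * i)) & 0xFF;
            const uint32_t byte =
                (sel == 0x0C) ? 0x00u : static_cast<uint32_t>((pool >> (8 * sel)) & 0xFF);
            out |= byte << (8 * i);
        }
        return out;
    }

    int main()
    {
        // The worked example from the inline docs:
        assert(perm_b32(0x11223344u, 0x55667788u, 0x05010400u) == 0x33774488u);

        // The fixed 2x2 8-bit path: rows (A0.B0) and (A1.B1) (dotted notation is
        // MSB first, so the B bytes sit in byte 0) become the transposed columns
        // (00.00.B1.B0) and (00.00.A1.A0).
        const uint32_t x0 = 0x0000A0B0u; // A0 = 0xA0, B0 = 0xB0
        const uint32_t x1 = 0x0000A1B1u; // A1 = 0xA1, B1 = 0xB1
        assert(perm_b32(x0, x1, 0x0C0C0004u) == 0x0000B1B0u); // (00.00.B1.B0)
        assert(perm_b32(x0, x1, 0x0C0C0105u) == 0x0000A1A0u); // (00.00.A1.A0)
        return 0;
    }

Under this model the old 8-bit 2x2 branch is visibly wrong: it fed zero-extended 16-bit rows through the 16-bit masks 0x05040100/0x07060302, which return the first row unchanged and zero out the second instead of interleaving single bytes, and it also swapped ix and iy between its two static_for loops.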


@@ -26,136 +26,193 @@ struct transpose_vectors
     using VX = array<S, s_per_x>;
     using VY = array<S, s_per_y>;
-    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
-                                   thread_buffer<VY, NY>& vy_tuple)
-    {
-        if constexpr(sizeof(S) == 4)
-        {
-            static_for<0, NY, 1>{}([&](auto iy) {
-                static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
-            });
-        }
-        else if constexpr(sizeof(S) == 2)
-        {
-            static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
-            using S2 = array<S, 2>; // typename array<S, 2>::type;
-            // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
-            static_for<0, NY, 2>{}([&](auto iy) {
-                static_for<0, NX, 2>{}([&](auto ix) {
-                    // 2 16bitx2 data from vx_tuple to be transposed
-                    const int32_t x_s2_0 =
-                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
-                    const int32_t x_s2_1 =
-                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
-                    constexpr int32_t m0 = 0x05040100;
-                    constexpr int32_t m1 = 0x07060302;
-                    // transpose 2x2 16bit
-                    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
-                    //                   -- -- -- --     -- -- -- --      -  -  -  -
-                    //       index        7  6  5  4      3  2  1  0     33 77 44 88
-                    // index is reversed because of little endianness (least significant bits first)
-                    const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
-                    const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
-                    // 2 16bitx2 data after transposed
-                    vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
-                    vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
-                });
-            });
-        }
-        else if constexpr(sizeof(S) == 1)
-        {
-            static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!");
-            using S4 = array<S, 4>; // typename array<S, 4>::type;
-            using S2 = array<S, 2>; // typename array<S, 4>::type;
-            if constexpr(NX % 4 == 0 && NY % 4 == 0)
-            {
-                // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
-                static_for<0, NY, 4>{}([&](auto iy) {
-                    static_for<0, NX, 4>{}([&](auto ix) {
-                        // 4 int8x4 data from vx_tuple
-                        const int32_t x_s4_0 =
-                            bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_1 =
-                            bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_2 =
-                            bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
-                        const int32_t x_s4_3 =
-                            bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
-                        // transpose
-                        int32_t t_s4_0, t_s4_1;
-                        int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
-                        constexpr int32_t m0 = 0x05010400;
-                        constexpr int32_t m1 = 0x05040100;
-                        constexpr int32_t m2 = 0x07060302;
-                        constexpr int32_t m3 = 0x07030602;
-                        // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
-                        //                   -- -- -- --     -- -- -- --      -  -  -  -
-                        //       index        7  6  5  4      3  2  1  0     33 77 44 88
-                        // index is reversed because of little endianness (least significant bits first)
-                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
-                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
-                        y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                        y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
-                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
-                        y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                        y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-                        // 4 int8x4 data from vy_tuple
-                        vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
-                        vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
-                        vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
-                        vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
-                    });
-                });
-            }
-            else if constexpr(NX % 2 == 0 && NY % 2 == 0)
-            {
-                static_for<0, NY, 2>{}([&](auto ix) {
-                    static_for<0, NX, 2>{}([&](auto iy) {
-                        const int16_t x_s2_0 =
-                            bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
-                        const int16_t x_s2_1 =
-                            bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
-                        constexpr int32_t m0 = 0x05040100;
-                        constexpr int32_t m1 = 0x07060302;
-                        const int32_t x0_32 = static_cast<int32_t>(x_s2_0 & 0xFFFF);
-                        const int32_t x1_32 = static_cast<int32_t>(x_s2_1 & 0xFFFF);
-                        const int32_t y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0);
-                        const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1);
-                        vy_tuple(iy).template get_as<S2>()[ix / I2] =
-                            bit_cast<S2>(static_cast<int16_t>(y_s2_0 & 0xFFFF));
-                        vy_tuple(iy + I1).template get_as<S2>()[ix / I2] =
-                            bit_cast<S2>(static_cast<int16_t>(y_s2_1 & 0xFFFF));
-                    });
-                });
-            }
-        }
-    }
+    struct generic_tag
+    {
+    };
+    struct bytesize2_2x2_tag
+    {
+    };
+    struct bytesize1_4x4_tag
+    {
+    };
+    struct bytesize1_2x2_tag
+    {
+    };
+    CK_TILE_DEVICE static constexpr void
+    apply_impl(const thread_buffer<VX, NX>& vx_tuple, thread_buffer<VY, NY>& vy_tuple, generic_tag)
+    {
+        static_for<0, NY, 1>{}([&](auto iy) {
+            static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
+        });
+    }
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize2_2x2_tag)
+    {
+        static_assert(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
+        using S2 = array<S, 2>;
+        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // 2 16bitx2 data from vx_tuple to be transposed
+                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
+                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
+                // transpose 2x2 16bit
+                // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+                //                   -- -- -- --     -- -- -- --      -  -  -  -
+                //       index        7  6  5  4      3  2  1  0     33 77 44 88
+                // index is reversed because of little endianness (least significant bits first)
+                const S2 y_s2_0 = bit_cast<S2>(
+                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
+                                          bit_cast<uint32_t>(x_s2_1),
+                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[1, 0, 5, 4] = (C1.D1.C0.D0)
+                                          0x01'00'05'04));
+                const S2 y_s2_1 = bit_cast<S2>(
+                    __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
+                                          bit_cast<uint32_t>(x_s2_1),
+                                          // (A0.B0.C0.D0.A1.B1.C1.D1)[3, 2, 7, 6] = (A1.B1.A0.B0)
+                                          0x03'02'07'06));
+                // write transposed 2x2 result:
+                // write (C1.D1.C0.D0)
+                vy_tuple(iy).set_as(ix / I2, y_s2_0);
+                // write (A1.B1.A0.B0)
+                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
+            });
+        });
+    }
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize1_4x4_tag)
+    {
+        static_assert(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0, "wrong!");
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
+        constexpr auto I3 = number<3>{};
+        constexpr auto I4 = number<4>{};
+        using S4 = array<S, 4>;
+        // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 4>{}([&](auto iy) {
+            static_for<0, NX, 4>{}([&](auto ix) {
+                // read A0.B0.C0.D0
+                const S4 x_s4_0 = vx_tuple[ix].template get_as<S4>(iy / I4);
+                // read A1.B1.C1.D1
+                const S4 x_s4_1 = vx_tuple[ix + I1].template get_as<S4>(iy / I4);
+                // read A2.B2.C2.D2
+                const S4 x_s4_2 = vx_tuple[ix + I2].template get_as<S4>(iy / I4);
+                // read A3.B3.C3.D3
+                const S4 x_s4_3 = vx_tuple[ix + I3].template get_as<S4>(iy / I4);
+                // (A1.B1.C1.D1.A0.B0.C0.D0)[5, 1, 4, 0] = (C1.C0.D1.D0)
+                uint32_t t_s4_0 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x05'01'04'00);
+                // (A3.B3.C3.D3.A2.B2.C2.D2)[5, 1, 4, 0] = (C3.C2.D3.D2)
+                uint32_t t_s4_1 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x05'01'04'00);
+                // (C3.C2.D3.D2.C1.C0.D1.D0)[5, 4, 1, 0] = (D3.D2.D1.D0)
+                const S4 y_s4_0 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
+                // (C3.C2.D3.D2.C1.C0.D1.D0)[7, 6, 3, 2] = (C3.C2.C1.C0)
+                const S4 y_s4_1 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
+                // (A1.B1.C1.D1.A0.B0.C0.D0)[7, 3, 6, 2] = (A1.A0.B1.B0)
+                t_s4_0 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x07'03'06'02);
+                // (A3.B3.C3.D3.A2.B2.C2.D2)[7, 3, 6, 2] = (A3.A2.B3.B2)
+                t_s4_1 = __builtin_amdgcn_perm(
+                    bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x07'03'06'02);
+                // (A3.A2.B3.B2.A1.A0.B1.B0)[5, 4, 1, 0] = (B3.B2.B1.B0)
+                const S4 y_s4_2 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
+                // (A3.A2.B3.B2.A1.A0.B1.B0)[7, 6, 3, 2] = (A3.A2.A1.A0)
+                const S4 y_s4_3 =
+                    bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
+                // write transposed 4x4 result:
+                // write (D3.D2.D1.D0)
+                vy_tuple(iy).set_as(ix / I4, y_s4_0);
+                // write (C3.C2.C1.C0)
+                vy_tuple(iy + I1).set_as(ix / I4, y_s4_1);
+                // write (B3.B2.B1.B0)
+                vy_tuple(iy + I2).set_as(ix / I4, y_s4_2);
+                // write (A3.A2.A1.A0)
+                vy_tuple(iy + I3).set_as(ix / I4, y_s4_3);
+            });
+        });
+    }
+    CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
+                                                    thread_buffer<VY, NY>& vy_tuple,
+                                                    bytesize1_2x2_tag)
+    {
+        static_assert(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
+        using S2 = array<S, 2>;
+        // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // read A0.B0
+                const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
+                // read A1.B1
+                const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
+                // v_perm_b32: pick 4 bytes from 8 bytes in (input0.input1) using the mask
+                const S2 y_s2_0 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
+                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 0, 4] = (00.00.B1.B0)
+                    0x0C'0C'00'04)));
+                const S2 y_s2_1 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
+                    static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
+                    // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 1, 5] = (00.00.A1.A0)
+                    0x0C'0C'01'05)));
+                // write transposed 2x2 result:
+                // write (B1.B0)
+                vy_tuple(iy).set_as(ix / I2, y_s2_0);
+                // write (A1.A0)
+                vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
+            });
+        });
+    }
+    CK_TILE_DEVICE static constexpr auto tag_dispatch()
+    {
+        if constexpr(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0)
+        {
+            return bytesize2_2x2_tag{};
+        }
+        else if constexpr(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0)
+        {
+            return bytesize1_4x4_tag{};
+        }
+        else if constexpr(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0)
+        {
+            return bytesize1_2x2_tag{};
+        }
+        else
+        {
+            static_assert(false, "not implemented");
+            return generic_tag{};
+        }
+    }
+    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
+                                   thread_buffer<VY, NY>& vy_tuple) const
+    {
+        apply_impl(vx_tuple, vy_tuple, tag_dispatch());
+    }
 };
 } // namespace ck_tile
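
For the 4x4 8-bit path, the same reference model can check the two-stage perm network in apply_impl(..., bytesize1_4x4_tag) against a naive transpose. A self-contained host-side sketch in the spirit of the test this commit adds (an illustration only, not the commit's actual test; perm_b32 is the hypothetical reference helper from the sketch above):

    #include <cassert>
    #include <cstdint>

    // Reference v_perm_b32 (data selectors 0..7 only), as in the sketch above.
    static uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t selector)
    {
        const uint64_t pool = (static_cast<uint64_t>(src0) << 32) | src1;
        uint32_t out        = 0;
        for(int i = 0; i < 4; ++i)
        {
            const uint32_t sel = (selector >> (8 * i)) & 0xFF;
            out |= static_cast<uint32_t>((pool >> (8 * sel)) & 0xFF) << (8 * i);
        }
        return out;
    }

    int main()
    {
        // 4x4 byte tile; row r is packed little-endian, element c in byte c,
        // so in dotted notation x[r] is (A.B.C.D) with D = in[r][0].
        uint8_t in[4][4];
        uint32_t x[4], y[4];
        for(int r = 0; r < 4; ++r)
        {
            x[r] = 0;
            for(int c = 0; c < 4; ++c)
            {
                in[r][c] = static_cast<uint8_t>(0x10 * r + c);
                x[r] |= static_cast<uint32_t>(in[r][c]) << (8 * c);
            }
        }

        // The two-stage perm network from the diff.
        uint32_t t0 = perm_b32(x[1], x[0], 0x05010400u); // (C1.C0.D1.D0)
        uint32_t t1 = perm_b32(x[3], x[2], 0x05010400u); // (C3.C2.D3.D2)
        y[0] = perm_b32(t1, t0, 0x05040100u);            // (D3.D2.D1.D0)
        y[1] = perm_b32(t1, t0, 0x07060302u);            // (C3.C2.C1.C0)
        t0 = perm_b32(x[1], x[0], 0x07030602u);          // (A1.A0.B1.B0)
        t1 = perm_b32(x[3], x[2], 0x07030602u);          // (A3.A2.B3.B2)
        y[2] = perm_b32(t1, t0, 0x05040100u);            // (B3.B2.B1.B0)
        y[3] = perm_b32(t1, t0, 0x07060302u);            // (A3.A2.A1.A0)

        // The network must agree with a plain transpose: byte c of y[r] == in[c][r].
        for(int r = 0; r < 4; ++r)
            for(int c = 0; c < 4; ++c)
                assert(((y[r] >> (8 * c)) & 0xFF) == in[c][r]);
        return 0;
    }

Each output row y[r] holds column r of the input tile, which is exactly the data movement the dotted (D3.D2.D1.D0)-style comments in the diff describe.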