[CK_TILE] Use read_tr in universal gemm (#2436)

* Use read_tr in universal gemm

* Enable all instances back

* Revert example37 changes

* Resolve comments

* resolve comments 2

* Fix assertion msg

* fix the gemm basic

* change index_t to bool for preshuffle variable

* Solve the comment

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>

[ROCm/composable_kernel commit: f1d8ad2818]
This commit is contained in:
Yi DING
2025-07-17 14:56:22 +08:00
committed by GitHub
parent ac0a7f44e1
commit 3bb687d502
26 changed files with 1150 additions and 565 deletions

View File

@@ -2841,11 +2841,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}

View File

@@ -2611,11 +2611,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}

View File

@@ -10,53 +10,55 @@
namespace ck_tile {
// this generate wave level tile distribution
template <typename T, typename = void>
template <typename T, index_t LaneGroupSize = 16, typename = void>
struct LaneGroupTransposeTraits;
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
{
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
"LaneGroupSize must be 16, 32, or 64");
// before transpose, 4x16
static constexpr index_t ksecondDim = 4;
static constexpr index_t kleadDim = 16;
static constexpr index_t kleadDim = LaneGroupSize;
// after transpose, 16x4
static constexpr index_t ksecondDimT = 16;
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 4;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
{
static constexpr index_t ksecondDim = 8;
static constexpr index_t kleadDim = 16;
static constexpr index_t kleadDim = LaneGroupSize;
static constexpr index_t ksecondDimT = 16;
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 8;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
/*
@@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
* consecutive.
*/
template <typename T,
index_t LaneGroupSize,
index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
{
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
return xdllevel_dstr_encoding{};
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
}
} // namespace ck_tile

View File

@@ -994,51 +994,34 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
// clang-format off
static_assert(
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
"wrong! not implemented for this combination, please add "
"implementation");
// clang-format on
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
@@ -1090,6 +1073,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{

View File

@@ -17,6 +17,11 @@
namespace ck_tile {
constexpr int DS_READ_TR_SIZE()
{
return 8; // Literal constant, evaluated at compile time
}
namespace util {
template <typename Suffix, typename Sequence>
struct is_sequence_suffix
@@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
template <typename DataType>
struct DefaultTranspose
{
template <index_t LaneGroupSize>
struct Quad16
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<4, 4>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
template <index_t LaneGroupSize>
struct Quad8
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<2, 8>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
// Select based on data size
template <index_t LaneGroupSize>
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::InputEncoding,
typename Quad8::InputEncoding>;
typename Quad16<LaneGroupSize>::InputEncoding,
typename Quad8<LaneGroupSize>::InputEncoding>;
template <index_t LaneGroupSize>
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::OutputEncoding,
typename Quad8::OutputEncoding>;
typename Quad16<LaneGroupSize>::OutputEncoding,
typename Quad8<LaneGroupSize>::OutputEncoding>;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};
@@ -96,51 +113,79 @@ struct DefaultTranspose
return idx; // Identity mapping
};
template <typename InDstrEncode>
struct ValidationTraits
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
struct ValidationTraitsImpl
{
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
using QuadEncoding = std::conditional_t<ReverseDirection,
QuadOutputEncoding<LaneGroupSize>,
QuadInputEncoding<LaneGroupSize>>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
// 1. Must be 2D tensor
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
// 2. Quad pattern must be suffix of input pattern
static constexpr bool suffix_valid_dim0 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
decltype(input_hs_lengthss.template get<0>())>;
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
static constexpr bool suffix_valid_dim1 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
decltype(input_hs_lengthss.template get<1>())>;
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
// 3. PS→RHS mapping constraints
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
static constexpr index_t ndimp_inner =
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
static constexpr auto input_ps_major_last =
input_ps_major[number<input_ps_major.size() - 1>{}];
static constexpr auto input_ps_minor_last =
input_ps_minor[number<input_ps_minor.size() - 1>{}];
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
input_hs[I1].size() - quad_hs[I1].size()>;
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
[](auto i) {
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
},
number<quad_ps_minor0.size()>{});
static constexpr bool ps_mapping_valid =
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
input_hs_lengthss[number<1>{}].size() - 2) &&
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
input_hs_lengthss[number<0>{}].size() - 1);
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
decltype(input_ps_minor_last)>;
// 4. YS→RHS mapping constraints
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
"YS->RHS mapping must be single dimension");
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
"YS->RHS mapping must be the last dimension");
static constexpr bool ys_mapping_valid =
(input_ys_to_rhs_major.back() == 2) &&
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
input_hs_lengthss[number<0>{}].size() - 2);
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
ps_mapping_valid && ys_mapping_valid;
};
template <typename InDstrEncode, bool ReverseDirection = false>
struct ValidationTraits
{
static constexpr bool value =
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
static constexpr index_t LaneGroupSize =
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
: 0;
};
};
template <typename TileDistribution_, typename DataType_, typename Policy>
struct TransposeTileDistrChecker
@@ -154,111 +199,150 @@ struct TransposeTileDistrChecker
// this is used to generate the transposed output tile distribution encoding
// based on the input tile distribution encoding
template <typename TileDistribution_,
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
struct OutputTileDistributionTraits
typename Policy = DefaultTranspose<DataType_>,
bool ReverseDirection = false>
struct TransposeTileDistributionTraits
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr index_t LaneGroupSize =
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
"The input tile distribution encoding is not valid for transpose!");
using QuadInputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
using QuadOutputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadInputEncoding<LaneGroupSize>,
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
static constexpr auto I0 = number<0>{};
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
static constexpr index_t dim0 = Policy::transpose_dims[0];
static constexpr index_t dim1 = Policy::transpose_dims[1];
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
// for transpose load
// append the reversed quad output hs lengths to the input hs lengthss after removing
// the quad_input_hs_lengthss
// then reverse the whole sequence to get the dst_out_hs_lengthss
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
static constexpr auto full_out_hs_lengthss = generate_tuple(
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
// dims and append the quad_output_hs_lengthss to the end of each dimension
static constexpr auto outer_hs_lengthss = generate_tuple(
[](auto i) {
return input_hs_lengthss[i]
.extract(typename arithmetic_sequence_gen<0,
input_hs_lengthss[i].size() -
quad_input_hs_lengthss[i].size(),
1>::type{})
.push_back(reversed_quad_output_hs_lengthss[i]);
constexpr auto input_i = input_hs_lengthss[i];
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
},
number<InDstrEncode::NDimX>{});
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
static constexpr auto dst_out_hs_lengthss = generate_tuple(
[](auto i) {
auto outer_i = reversed_outer_hs_lengthss[i];
// append the reversed quad output hs lengths to the outer hs lengths
return outer_i.push_back(quad_output_hs_lengthss[i]);
},
number<InDstrEncode::NDimX>{});
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
// sequence
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
// thread distr) of the major sequence
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
[](auto i) {
if constexpr(i == input_ps_to_rhss_major.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_major[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_major.push_back(number<2>{});
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_major[i];
// For all other sequences (i.e. warp), keep them unchanged
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
}
},
number<input_ps_to_rhss_major.size()>{});
static constexpr auto minor_last_index =
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
static constexpr auto quad_idx_offset =
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
// minus 1 because RsLength is not counted
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
[](auto i) {
constexpr auto input_i = input_ps_to_rhss_minor[i];
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
constexpr auto outer_ps =
typename sequence_split<decltype(input_i), outer_len>::left_type{};
return outer_ps.push_back(quad_output_ps_minor_offset +
quad_output_ps_to_rhss_minor0);
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_minor[i];
return input_i;
}
},
number<input_ps_to_rhss_minor.size()>{});
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
number<modified_ps_to_rhss_major.size()>{});
static constexpr auto dst_ys_to_rhs_major =
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
static constexpr auto modified_input_ys_to_rhs_major =
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
number<modified_input_ys_to_rhs_major.size()>{});
static constexpr auto dst_ys_to_rhs_minor =
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
using TransposedDstrEncode =
tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
};
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
using OutputTileDistributionTraits =
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
using InputTileDistributionTraits =
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
@@ -321,9 +405,9 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& tile_window)
{
using OutTileDstrEncode =
typename OutputTileDistributionTraits<TileDistribution_,
typename BottomTensorView_::DataType>::OutDstrEncode;
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();

View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdio.h>
#include <tuple>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
template <auto... val>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
template <typename... type>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
template <char... Xs>
struct str_literal
{
static constexpr const char data[] = {Xs..., '\0'};
static constexpr const size_t size = sizeof...(Xs);
template <char... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
{
return str_literal<Xs..., Ys...>{};
}
template <index_t N, char... Ys>
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
{
if constexpr(N == 0)
return str_literal<>{};
else if constexpr(N == 1)
return str_literal<Xs...>{};
else
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
}
};
#define make_str_literal(lit_) \
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
return {};
}
constexpr size_t constexpr_strlen(const char* c)
{
size_t t = 0;
while(*c++)
++t;
return t;
}
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
template <typename T_, index_t N_>
struct thread_buffer;
// Usage example: CK_PRINTF<float>{}(tensor);
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF;
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
struct CK_PRINTF<ConvertTo,
str_literal<FMTChars...>,
str_literal<PREFIXChars...>,
str_literal<SUFFIXChars...>>
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr auto default_format()
{
if constexpr(std::is_same_v<T, float>)
return make_str_literal("%8.3f");
else if constexpr(std::is_same_v<T, int>)
return make_str_literal("%5d");
else if constexpr(std::is_same_v<T, unsigned int>)
return make_str_literal("%5u");
else
return make_str_literal("0x%08x");
}
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
{
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
if constexpr(sizeof...(PREFIXChars) == 0)
return fmt_tid;
else
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
{
constexpr auto lf = make_str_literal("\n");
if constexpr(sizeof...(SUFFIXChars) == 0)
return lf;
else
return str_literal<SUFFIXChars...>{} + lf;
}
template <typename T, index_t N, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
std::integer_sequence<index_t, Is...>) const
{
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
decltype(default_format<Y>()),
str_literal<FMTChars...>>;
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
#pragma clang diagnostic pop
}
template <typename T, index_t N>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
{
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
}
template <typename... TS>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
{
return operator()(tensor.get_thread_buffer());
}
};
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
{
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
template <typename T>
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
{
if(get_thread_id() < get_warp_size())
base_t::operator()(buf);
}
};
} // namespace ck_tile