mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[CK_TILE] Use read_tr in universal gemm (#2436)
* Use read_tr in universal gemm
* Enable all instances back
* Revert example37 changes
* Resolve comments
* resolve comments 2
* Fix assertion msg
* fix the gemm basic
* change index_t to bool for preshuffle variable
* Solve the comment
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
[ROCm/composable_kernel commit: f1d8ad2818]
This commit is contained in:
@@ -2841,11 +2841,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -2611,11 +2611,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -10,53 +10,55 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, typename = void>
|
||||
template <typename T, index_t LaneGroupSize = 16, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
|
||||
"LaneGroupSize must be 16, 32, or 64");
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t LaneGroupSize,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
|
||||
return xdllevel_dstr_encoding{};
|
||||
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -994,51 +994,34 @@ struct buffer_view<address_space_enum::lds,
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
// clang-format off
|
||||
static_assert(
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
// clang-format on
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
@@ -1090,6 +1073,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
|
||||
{
|
||||
|
||||
@@ -17,6 +17,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr int DS_READ_TR_SIZE()
|
||||
{
|
||||
return 8; // Literal constant, evaluated at compile time
|
||||
}
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
@@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad16
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<4, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad8
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<2, 8>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::InputEncoding,
|
||||
typename Quad8::InputEncoding>;
|
||||
typename Quad16<LaneGroupSize>::InputEncoding,
|
||||
typename Quad8<LaneGroupSize>::InputEncoding>;
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::OutputEncoding,
|
||||
typename Quad8::OutputEncoding>;
|
||||
typename Quad16<LaneGroupSize>::OutputEncoding,
|
||||
typename Quad8<LaneGroupSize>::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
@@ -96,51 +113,79 @@ struct DefaultTranspose
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode>
|
||||
struct ValidationTraits
|
||||
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
|
||||
decltype(input_hs_lengthss.template get<0>())>;
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
|
||||
decltype(input_hs_lengthss.template get<1>())>;
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
|
||||
static constexpr index_t ndimp_inner =
|
||||
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
|
||||
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
|
||||
|
||||
static constexpr auto input_ps_major_last =
|
||||
input_ps_major[number<input_ps_major.size() - 1>{}];
|
||||
static constexpr auto input_ps_minor_last =
|
||||
input_ps_minor[number<input_ps_minor.size() - 1>{}];
|
||||
|
||||
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
|
||||
input_hs[I1].size() - quad_hs[I1].size()>;
|
||||
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
|
||||
},
|
||||
number<quad_ps_minor0.size()>{});
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
|
||||
input_hs_lengthss[number<1>{}].size() - 2) &&
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 1);
|
||||
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
|
||||
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
|
||||
decltype(input_ps_minor_last)>;
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
|
||||
"YS->RHS mapping must be single dimension");
|
||||
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
|
||||
"YS->RHS mapping must be the last dimension");
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_to_rhs_major.back() == 2) &&
|
||||
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
|
||||
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
|
||||
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 2);
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection = false>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr bool value =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
|
||||
: 0;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
@@ -154,111 +199,150 @@ struct TransposeTileDistrChecker
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistribution_,
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
struct OutputTileDistributionTraits
|
||||
typename Policy = DefaultTranspose<DataType_>,
|
||||
bool ReverseDirection = false>
|
||||
struct TransposeTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
|
||||
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
|
||||
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
|
||||
"The input tile distribution encoding is not valid for transpose!");
|
||||
|
||||
using QuadInputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
|
||||
using QuadOutputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
|
||||
|
||||
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
|
||||
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr index_t dim0 = Policy::transpose_dims[0];
|
||||
static constexpr index_t dim1 = Policy::transpose_dims[1];
|
||||
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
|
||||
// for transpose load
|
||||
// append the reversed quad output hs lengths to the input hs lengthss after removing
|
||||
// the quad_input_hs_lengthss
|
||||
// then reverse the whole sequence to get the dst_out_hs_lengthss
|
||||
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
|
||||
|
||||
static constexpr auto full_out_hs_lengthss = generate_tuple(
|
||||
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
|
||||
// dims and append the quad_output_hs_lengthss to the end of each dimension
|
||||
static constexpr auto outer_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
return input_hs_lengthss[i]
|
||||
.extract(typename arithmetic_sequence_gen<0,
|
||||
input_hs_lengthss[i].size() -
|
||||
quad_input_hs_lengthss[i].size(),
|
||||
1>::type{})
|
||||
.push_back(reversed_quad_output_hs_lengthss[i]);
|
||||
constexpr auto input_i = input_hs_lengthss[i];
|
||||
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
|
||||
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
|
||||
static constexpr auto dst_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
auto outer_i = reversed_outer_hs_lengthss[i];
|
||||
// append the reversed quad output hs lengths to the outer hs lengths
|
||||
return outer_i.push_back(quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
|
||||
// sequence
|
||||
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
|
||||
// thread distr) of the major sequence
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
|
||||
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
|
||||
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.push_back(number<2>{});
|
||||
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_major[i];
|
||||
// For all other sequences (i.e. warp), keep them unchanged
|
||||
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto minor_last_index =
|
||||
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
|
||||
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
|
||||
static constexpr auto quad_idx_offset =
|
||||
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
|
||||
|
||||
// minus 1 because RsLength is not counted
|
||||
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
|
||||
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_ps_to_rhss_minor[i];
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
|
||||
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
|
||||
constexpr auto outer_ps =
|
||||
typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
|
||||
return outer_ps.push_back(quad_output_ps_minor_offset +
|
||||
quad_output_ps_to_rhss_minor0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_minor[i];
|
||||
return input_i;
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
|
||||
number<modified_ps_to_rhss_major.size()>{});
|
||||
static constexpr auto dst_ys_to_rhs_major =
|
||||
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
|
||||
|
||||
static constexpr auto modified_input_ys_to_rhs_major =
|
||||
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
|
||||
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
|
||||
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
|
||||
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
|
||||
number<modified_input_ys_to_rhs_major.size()>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor =
|
||||
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
|
||||
|
||||
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
using TransposedDstrEncode =
|
||||
tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using OutputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using InputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
@@ -321,9 +405,9 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<TileDistribution_,
|
||||
typename BottomTensorView_::DataType>::OutDstrEncode;
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
|
||||
156
include/ck_tile/core/utility/debug.hpp
Normal file
156
include/ck_tile/core/utility/debug.hpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
|
||||
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct thread_buffer;
|
||||
|
||||
// Usage example: CK_PRINTF<float>{}(tensor);
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF;
|
||||
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINTF<ConvertTo,
|
||||
str_literal<FMTChars...>,
|
||||
str_literal<PREFIXChars...>,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return make_str_literal("%8.3f");
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return make_str_literal("%5d");
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return make_str_literal("%5u");
|
||||
else
|
||||
return make_str_literal("0x%08x");
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>) const
|
||||
{
|
||||
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
|
||||
}
|
||||
|
||||
template <typename... TS>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
|
||||
{
|
||||
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user