Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev

This commit is contained in:
mtgu0705
2025-09-08 22:09:15 -05:00
1276 changed files with 113756 additions and 18739 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength>
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
{
printf("pass_through{");
printf("up_lengths_: ");
print(pt.get_upper_lengths());
printf("}");
}
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
{
printf("pad{");
printf("up_lengths_: ");
print(p.up_lengths_);
printf(", left_pad_length_: ");
print(p.left_pad_length_);
printf(", right_pad_length_: ");
print(p.right_pad_length_);
printf("}");
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
@@ -330,24 +326,20 @@ struct left_pad
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
{
printf("left_pad{");
printf("up_lengths_: ");
print(lp.up_lengths_);
printf(", left_pad_length_: ");
print(lp.left_pad_length_);
printf("}");
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
{
printf("right_pad{");
printf("up_lengths_: ");
print(rp.up_lengths_);
printf(", right_pad_length_: ");
print(rp.right_pad_length_);
printf("}");
}
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename UpLengths, typename Coefficients>
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
{
printf("embed{");
printf("up_lengths_: ");
print(e.up_lengths_);
printf(", coefficients_: ");
print(e.coefficients_);
printf("}");
}
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
{
printf("merge_v2_magic_division{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
{
printf("merge_v3_division_mod{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", low_lengths_scan_: ");
print(m.low_lengths_scan_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
{
printf("unmerge{");
printf("up_lengths_: ");
print(u.up_lengths_);
printf(", up_lengths_scan_: ");
print(u.up_lengths_scan_);
printf("}");
}
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
template <typename LowerIndex>
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
{
printf("freeze{");
printf("low_idx_: ");
print(f.low_idx_);
printf("}");
}
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
template <typename UpperLength>
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
{
printf("insert{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf("}");
}
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename UpLengths>
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
{
printf("replicate{");
printf("up_lengths_: ");
print(r.up_lengths_);
printf("}");
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // namespace ck
}; // namespace ck
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
{
printf("slice{");
printf("up_lengths_: ");
print(s.up_lengths_);
printf(", slice_begin_: ");
print(s.slice_begin_);
printf(", slice_end_: ");
print(s.slice_end_);
printf("}");
}
/*
* \brief lower_idx = upper_idx % modulus.
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
{
printf("modulo{");
printf("modulus_: ");
print(m.modulus_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
{
printf("xor_t{");
printf("up_lengths_: ");
print(x.up_lengths_);
printf("}");
}
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
{
printf("offset{");
printf("up_lengths_: ");
print(o.up_lengths_);
printf(", offset_length_: ");
print(o.offset_length_);
printf("}");
}
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
{
printf("indexing{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf(", iadaptor_: ");
print(i.iadaptor_);
printf("}");
}
//*******************************************************************************************************
template <typename LowLength>

View File

@@ -100,10 +100,8 @@ struct space_filling_curve
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
// idim-th element of multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
auto res = idx_1d.value;
auto id = 0;

View File

@@ -77,6 +77,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -103,7 +104,7 @@ enum struct tile_distribution_pattern
block_raked,
};
struct TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern
{
};
@@ -125,7 +126,7 @@ template <index_t BlockSize,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups = 1>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern
{
};
@@ -135,12 +136,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
@@ -164,7 +166,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
"X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(NumWaveGroups != 1)
{
@@ -188,7 +190,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
if constexpr(NumWaveGroups != 1)
{
@@ -219,12 +221,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -243,7 +246,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -254,7 +257,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<1, 1>>{}); // -> <Y1, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -272,12 +275,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
@@ -294,7 +298,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -305,7 +309,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<0, 1>>{}); // -> <Y0, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -317,4 +321,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
};
// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
switch(pattern)
{
case tile_distribution_pattern::thread_raked: return "thread_raked";
case tile_distribution_pattern::warp_raked: return "warp_raked";
case tile_distribution_pattern::block_raked: return "block_raked";
default: return "unknown";
}
}
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
{
using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
"VecSize:%d, %s>: ",
BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern_to_string(DistributionPattern));
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
PatternType::Y0,
PatternType::Y1,
PatternType::Y2,
PatternType::X0,
PatternType::X1);
}
} // namespace ck_tile

View File

@@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
// r.x = __builtin_amdgcn_readfirstlane(r.x);
// r.y = __builtin_amdgcn_readfirstlane(r.y);
// r.z = __builtin_amdgcn_readfirstlane(r.z);
// r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
@@ -1288,26 +1284,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \
if constexpr(pre_nop) \
asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \
: "=r"(smem) /*dummy dependency for smem*/ \
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
: "memory"); \
else \
asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \
: "=r"(smem) /*dummy dependency for smem*/ \
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
: "memory");
if constexpr(num_dwords == 1)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword");
}
#if defined(__gfx950__)
else if constexpr(num_dwords == 3)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3");
}
else if constexpr(num_dwords == 4)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4");
}
#endif
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
{
static_assert(false, "wrong! not implemented data width");
}
#undef CK_TILE_ASYNC_LOAD_WITH_INSTR
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
@@ -1326,6 +1342,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,
@@ -1770,15 +1797,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t num_bytes = sizeof(T) * N;
constexpr index_t num_words = num_bytes / 4;
static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4),
"wrong! only support in dword, dwordx3, dwordx4");
async_buffer_load_dword_v(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
async_buffer_load_dwordxn_v<num_words>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
}
template <typename T,
@@ -1794,60 +1824,37 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}
template <index_t N,
@@ -2826,78 +2833,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource)
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
#define __LDS_ADDR __attribute__((address_space(3)))
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
#pragma clang diagnostic pop
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
__attribute__((address_space(3))) cktile_llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) cktile_llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
#undef __LDS_ADDR
}
#endif

View File

@@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
@@ -1148,26 +1144,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \
if constexpr(pre_nop) \
asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \
: "=r"(smem) /*dummy dependency for smem*/ \
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
: "memory"); \
else \
asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \
: "=r"(smem) /*dummy dependency for smem*/ \
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
: "memory");
if constexpr(num_dwords == 1)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword");
}
#if defined(__gfx950__)
else if constexpr(num_dwords == 3)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3");
}
else if constexpr(num_dwords == 4)
{
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4");
}
#endif
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
{
static_assert(false, "wrong! not implemented data width");
}
#undef CK_TILE_ASYNC_LOAD_WITH_INSTR
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
@@ -1186,6 +1202,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,
@@ -1534,15 +1561,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t num_bytes = sizeof(T) * N;
constexpr index_t num_words = num_bytes / 4;
static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4),
"wrong! only support in dword, dwordx3, dwordx4");
async_buffer_load_dword_v(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
async_buffer_load_dwordxn_v<num_words>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
}
template <typename T,
@@ -1558,60 +1588,37 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}
template <index_t N,
@@ -2598,11 +2605,6 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
@@ -2619,49 +2621,69 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
"s"(src_resource)
: "memory");
#else
// Direct loads require that each thread reads and writes exactly a single DWORD.
#if defined(__gfx9__)
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
#endif
// Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes).
// For gfx950: supports 1, 3, or 4 DWORDs per thread
// For gfx942: supports exactly 1 DWORD per thread
#if defined(__gfx950__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
bytes_per_thread == dword_bytes * 4);
#elif defined(__gfx9__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes);
#endif
// LDS pointer must be attributed with the LDS address space.
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
#define __LDS_ADDR __attribute__((address_space(3)))
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
#pragma clang diagnostic pop
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
__attribute__((address_space(3))) cktile_llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) cktile_llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
#undef __LDS_ADDR
}
#endif

View File

@@ -10,53 +10,55 @@
namespace ck_tile {
// this generate wave level tile distribution
template <typename T, typename = void>
template <typename T, index_t LaneGroupSize = 16, typename = void>
struct LaneGroupTransposeTraits;
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
{
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
"LaneGroupSize must be 16, 32, or 64");
// before transpose, 4x16
static constexpr index_t ksecondDim = 4;
static constexpr index_t kleadDim = 16;
static constexpr index_t kleadDim = LaneGroupSize;
// after transpose, 16x4
static constexpr index_t ksecondDimT = 16;
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 4;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
template <typename T, index_t LaneGroupSize>
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
{
static constexpr index_t ksecondDim = 8;
static constexpr index_t kleadDim = 16;
static constexpr index_t kleadDim = LaneGroupSize;
static constexpr index_t ksecondDimT = 16;
static constexpr index_t ksecondDimT = LaneGroupSize;
static constexpr index_t kleadDimT = 8;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
using TileDistribution = tile_distribution_encoding<
sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
tuple<sequence<1, 2, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2, 3>>,
sequence<2, 1, 2>,
sequence<1, 1, 4>>;
};
/*
@@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
* consecutive.
*/
template <typename T,
index_t LaneGroupSize,
index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
{
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
return xdllevel_dstr_encoding{};
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
}
} // namespace ck_tile

View File

@@ -9,6 +9,16 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
#define CK_TILE_EXPCNT(cnt) \
([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
#define CK_TILE_LGKMCNT(cnt) \
([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
namespace ck_tile {
@@ -57,6 +67,23 @@ CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
#endif
}
// Host-side query: does the currently selected HIP device execute in wave32
// mode (32-lane wavefronts)?
// Returns false when the device or its properties cannot be queried.
CK_TILE_HOST bool is_wave32()
{
    hipDeviceProp_t props{};
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
    {
        return false;
    }
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
    {
        return false;
    }
    // Query the wavefront width directly instead of inferring it from the
    // architecture major version (previously `props.major > 9`, i.e. gfx10+).
    // warpSize reports exactly the device's native wavefront size.
    return props.warpSize == 32;
}
// Number of thread blocks in the launch grid (x dimension only).
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }

// Number of threads per block (x dimension only).
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
@@ -71,30 +98,24 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
CK_TILE_DEVICE index_t get_warp_id()
template <bool ReturnSgpr = true>
CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
const index_t warp_id = threadIdx.x / get_warp_size();
if constexpr(ReturnSgpr)
{
return __builtin_amdgcn_readfirstlane(warp_id);
}
else
{
return warp_id;
}
}
// Thread index within the block (x dimension only).
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }

// Block index within the grid (x dimension only).
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
// Block-wide barrier ordering LDS (shared-memory) accesses across the workgroup.
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
    // asm volatile("\
    // s_waitcnt lgkmcnt(0) \n \
    // s_barrier \
    // " ::);
    // 0xc07f leaves the vmcnt/expcnt fields at their maxima (do not wait) and
    // sets lgkmcnt to 0, i.e. wait only for outstanding LDS/GDS/K/M operations
    // before the barrier — vector-memory traffic is deliberately NOT synced.
    __builtin_amdgcn_s_waitcnt(0xc07f);
    __builtin_amdgcn_s_barrier();
#else
    // Default path: full __syncthreads(), which also waits on vector memory.
    __syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
@@ -113,13 +134,68 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
#endif
}
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
// Builds the immediate operand of the gfx9 s_waitcnt instruction from the
// three individual hardware counters (VM / EXP / LGKM).
struct waitcnt_arg
{
    // Immediate layout, bit F down to 0: VV'UU'LLLL'U'EEE'VVVV
    //   VM   counter -> split field: bits [3:0] (low) plus bits [15:14] (high)
    //   EXP  counter -> bits [6:4]
    //   LGKM counter -> bits [11:8]
    //   U            -> unused bits
    CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
    CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
    CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
    CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;

    // Places a VM counter value into its split bit field.
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_vmcnt()
    {
        static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
        constexpr index_t low_bits  = cnt & 0b1111;     // -> bits [3:0]
        constexpr index_t high_bits = (cnt >> 4) << 14; // -> bits [15:14]
        return MAX & (low_bits | high_bits);
    }

    // Places an EXP counter value into bits [6:4].
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_expcnt()
    {
        static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
        return MAX & (cnt << 4);
    }

    // Places an LGKM counter value into bits [11:8].
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
    {
        static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
        return MAX & (cnt << 8);
    }
};
// Issues s_waitcnt with the given counter thresholds. Each counter defaults to
// its maximum, which means "do not wait on this counter".
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
          index_t expcnt = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
    // Compose the full immediate once, then issue a single wait instruction.
    constexpr index_t immediate = waitcnt_arg::from_vmcnt<vmcnt>() |
                                  waitcnt_arg::from_expcnt<expcnt>() |
                                  waitcnt_arg::from_lgkmcnt<lgkmcnt>();
    __builtin_amdgcn_s_waitcnt(immediate);
}
// Waits on the requested counters, then performs a workgroup barrier.
// Counters left at their defaults (maxima) are not waited on.
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
          index_t expcnt = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
    s_waitcnt<vmcnt, expcnt, lgkmcnt>();
    __builtin_amdgcn_s_barrier();
}
// LDS synchronization: wait until at most `lgkmcnt` LDS/GDS/K/M operations are
// outstanding (default 0 = all complete), then barrier. The vm/exp counters
// are left at their maxima, i.e. vector-memory traffic is not waited on.
template <index_t lgkmcnt = 0>
CK_TILE_DEVICE void block_sync_lds()
{
    s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
}
template <index_t vmcnt = 0>
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
}
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
@@ -166,4 +242,35 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
#endif
}
/// Helper function to convert address space enum to string
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
    // Explicit per-enumerator mapping; anything unrecognized reports "unknown".
    if(addr_space == address_space_enum::generic)
        return "generic";
    if(addr_space == address_space_enum::global)
        return "global";
    if(addr_space == address_space_enum::lds)
        return "lds";
    if(addr_space == address_space_enum::sgpr)
        return "sgpr";
    if(addr_space == address_space_enum::constant)
        return "constant";
    if(addr_space == address_space_enum::vgpr)
        return "vgpr";
    return "unknown";
}
// Architecture tags
// Empty tag type used to select gfx11-specific code paths at compile time.
struct gfx11_t
{
};
// Empty tag type used to select gfx12-specific code paths at compile time.
struct gfx12_t
{
};
// Returns an architecture tag object describing the target this translation
// unit is being compiled for.
// NOTE(review): any target that is not gfx11 falls through to gfx12_t —
// including non-gfx12 builds; confirm callers only instantiate this on
// gfx11/gfx12 compilations.
CK_TILE_DEVICE static constexpr auto get_device_arch()
{
#if defined(__gfx11__)
    return gfx11_t{};
#else // if defined(__gfx12__)
    return gfx12_t{};
#endif
}
} // namespace ck_tile

View File

@@ -6,6 +6,10 @@
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
namespace ck_tile {
template <typename T, typename ComputeType>
@@ -40,6 +44,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
return rtn;
}
// Lane-wise addition of two packed fp16 pairs; the arithmetic is carried out
// in float (via add<fp16_t, float>) to limit intermediate rounding error.
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
    fp16x2_t result;
    for(int lane = 0; lane < 2; ++lane)
    {
        result[lane] = add<fp16_t, float>(a[lane], b[lane]);
    }
    return result;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
@@ -343,6 +355,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
} while(cur_v.u64 != old_v);
}
//
// Atomic add for fp16x2_t
//
// Atomically adds a packed fp16 pair into *p_dst.
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
    // Fast path: hardware packed-fp16 atomic add builtin.
    __builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
#else
    // Fallback: emulate with a 32-bit compare-and-swap loop. The unions view
    // the same 32 bits either as a raw uint32_t (for atomicCAS) or as the
    // packed fp16x2_t (for the floating-point addition).
    union U32F162_ADDR
    {
        uint32_t* u32_a;
        fp16x2_t* f162_a;
    };
    union U32F162
    {
        uint32_t u32;
        fp16x2_t f162;
    };
    U32F162_ADDR dword_addr;
    U32F162 cur_v;
    U32F162 new_;
    uint32_t old_v, new_v;
    dword_addr.f162_a = p_dst;
    // Read the current value once, then retry the CAS until no other thread
    // has modified the destination between our read and our write.
    cur_v.u32 = *dword_addr.u32_a;
    do
    {
        old_v = cur_v.u32;
        new_.f162 = add_f16x2_t(cur_v.f162, x);
        new_v = new_.u32;
        // atomicCAS returns the value observed at the address; a mismatch with
        // old_v means we lost the race and must recompute with the new value.
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -350,6 +400,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
@@ -445,6 +496,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
x.template get_as<fp16x2_t>()[i]);
});
}
}
template <typename T, index_t N>

View File

@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
// Exchanges a 32-bit value with the partner lane in the other 32-lane half of
// the wavefront using __builtin_amdgcn_permlane32_swap, returning a 2-element
// buffer holding both exchanged results.
// NOTE(review): the exact lane pairing follows the permlane32_swap instruction
// semantics (swap between lanes 0-31 and 32-63) — confirm element order
// against the ISA documentation before relying on v(0)/v(1) positions.
template <typename T>
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
{
    // The builtin operates on i32, so only 32-bit payloads can be bit-cast in.
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
    const int32x2_t x = __builtin_amdgcn_permlane32_swap(
        bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
    thread_buffer<T, 2> v;
    v(0) = bit_cast<T>(x[0]);
    v(1) = bit_cast<T>(x[1]);
    return v;
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{

View File

@@ -3,10 +3,11 @@
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \
defined(__gfx9_4_generic__)
#define __gfx9__
#endif
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
@@ -152,7 +153,7 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx9__) // for GPU code
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
@@ -191,6 +192,16 @@
#endif
#endif
// use llvm builtin bf16 data type after ROCm 6.5
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
(HIP_VERSION_MAJOR >= 7)
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
#else
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
@@ -253,7 +264,7 @@
#endif
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
#if __clang_major__ == 20
#if __clang_major__ >= 20
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
#else
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0

View File

@@ -177,9 +177,27 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
// Prints an array as "array{size: N, data: [e0, e1, ...]}" by recursively
// delegating each element to its own print overload.
template <typename T, index_t N>
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
{
    printf("array{size: %ld, data: [", static_cast<long>(N));
    for(index_t idx = 0; idx < N; ++idx)
    {
        // Separator goes before every element except the first.
        if(idx != 0)
        {
            printf(", ");
        }
        print(a[idx]);
    }
    printf("]}");
}
// Overload for the zero-length array: nothing to iterate, print the fixed form.
template <typename T>
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
{
    printf("array{size: 0, data: []}");
}
template <typename, typename>
struct vector_traits;

View File

@@ -16,7 +16,7 @@ template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}

View File

@@ -139,26 +139,21 @@ struct map
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
// Prints a map as "map{size_: S, impl_: [{key: k, data: d}, ...]}" using the
// element types' own print overloads for keys and values.
template <typename key, typename data, index_t max_size>
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
{
    printf("map{size_: %d, impl_: [", m.size_);
    for(const auto& [map_key, map_value] : m)
    {
        printf("{key: ");
        print(map_key);
        printf(", data: ");
        print(map_value);
        printf("}, ");
    }
    printf("]}");
}
} // namespace ck_tile

View File

@@ -9,13 +9,10 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -199,15 +196,24 @@ struct sequence
{
return sequence<f(Is)...>{};
}
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
// Prints a compile-time sequence as "sequence<0, 1, 2>"; an empty sequence
// prints "sequence<>".
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
    printf("sequence<");
    if constexpr(sizeof...(Is) > 0)
    {
        index_t emitted = 0;
        // Fold over the pack: the first element prints bare, every later one
        // is preceded by ", ".
        ((printf(emitted++ == 0 ? "%d" : ", %d", Is)), ...);
    }
    printf(">");
}
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;

View File

@@ -42,7 +42,11 @@ struct thread_buffer {
// TODO: this ctor can't ignore
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
static_for<0, N, 1>{}(
[&](auto i) { data[i] = o; }
);
}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE auto & get() {return data; }

View File

@@ -262,12 +262,18 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
return flag;
}
CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; }
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
@@ -294,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename, typename = void>
// Prints a tuple as "tuple<e0, e1, ...>" by recursively printing each element
// via its own print overload.
template <typename... T>
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
{
    printf("tuple<");
    if constexpr(sizeof...(T) > 0)
    {
        index_t emitted = 0;
        // static_for visits every element index at compile time; the runtime
        // counter inserts ", " before all elements except the first.
        static_for<0, sizeof...(T), 1>{}([&](auto i) {
            if(emitted++ != 0)
                printf(", ");
            print(t.get(i));
        });
    }
    printf(">");
}
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
struct vector_traits<tuple<T...>, void>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);
@@ -470,6 +493,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
// Invokes f with the tuple's elements expanded as individual arguments; the
// sequence<Is...> parameter carries the element indices to expand.
template <typename F, typename Tuple, index_t... Is>
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
{
    // Forward the tuple itself so rvalue tuples can move their elements out.
    return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
}
} // namespace detail
template <typename F, typename X>
@@ -493,6 +522,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// ck_tile analogue of std::apply: calls f(t.get<0>(), ..., t.get<N-1>()) and
// returns its result, preserving value category via decltype(auto).
template <typename F, typename Tuple>
constexpr decltype(auto) apply(F&& f, Tuple&& t)
{
    constexpr index_t N = std::decay_t<Tuple>::size();
    return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
}
namespace detail {
template <typename F, typename X, index_t... Is>

View File

@@ -6,6 +6,9 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#if CK_TILE_USE_LLVM_BUILTIN_BF16
#include <hip/hip_bfloat16.h>
#endif
#include <stdint.h>
#pragma once
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
#if CK_TILE_USE_LLVM_BUILTIN_BF16
using bfloat16_t = __bf16;
#else
using bfloat16_t = ushort;
#endif
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
#endif
}
template <bf16_rounding_mode rounding =

View File

@@ -31,9 +31,10 @@ struct e8m0_bexp_t
raw_type data;
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
{
data = numeric_utils<float>::get_exponent(scale);
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
@@ -86,7 +87,7 @@ CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
using traits = numeric_traits<float>;
if(data == numeric<e8m0_t>::binary_nan)
{
return traits::NaN;
return std::numeric_limits<float>::signaling_NaN();
}
else if(data == 0)
{

View File

@@ -43,19 +43,19 @@ enum class fp8_interpretation
};
/*
* ______________FNUZ_________________ | ______________OCP________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : N/A N/A | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
* 2^-10(0.0009765625) 2^-17(7.62939e-06) | 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
@@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 7; // OCP
#else
static constexpr int bias = 8; // FNUZ
static constexpr int bias = 8; // FNUZ
#endif
using raw_type = uint8_t;
raw_type data;
@@ -259,50 +259,50 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// fp8/bf8 type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr int DstT_bias = numeric_traits<DstT>::bias;
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int bias = numeric_traits<SrcT>::bias;
constexpr unsigned int fInf = numeric_traits<SrcT>::Inf;
constexpr unsigned int abs_mask = numeric_traits<SrcT>::abs_mask;
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
unsigned long long head, mantissa;
int exponent, bias;
unsigned int head, mantissa;
int exponent;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(DstT_exp == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
unsigned int ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
@@ -343,9 +343,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// Deal with inf and NaNs
if((src_bitwise & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
@@ -354,11 +351,6 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
@@ -367,8 +359,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
constexpr int f8_denormal_act_exponent = 1 - DstT_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
@@ -406,11 +397,16 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}
bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
@@ -422,31 +418,31 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << SrcT_mant);
bool implicit_one = mantissa & (1u << SrcT_mant);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
(act_exponent + exponent_diff) /*actual f8 exponent*/ + DstT_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
bool odd =
mantissa & (1ull << (SrcT_mant -
DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa &
(1u << (SrcT_mant - DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << SrcT_mant) & mantissa)
if((1u << SrcT_mant) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (SrcT_mant + 1)) & mantissa)
if((1u << (SrcT_mant + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
@@ -471,9 +467,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
mantissa &= (1 << DstT_mant) - 1;
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
}
template <typename SrcT, typename DstT, bool clip = true>
@@ -481,8 +477,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
"SrcT type must be fp8 or bf8.");
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr uint8_t SrcT_abs_mask = numeric_traits<SrcT>::abs_mask;
constexpr bool is_fnuz =
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
@@ -518,9 +515,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
unsigned int sign = x >> (SrcT_exp + SrcT_mant);
unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
if constexpr(is_fnuz)
{
if((x & 0xff) == 0x80)
@@ -559,7 +556,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{
retval = x << 8;
retval = static_cast<typename numeric_traits<DstT>::bitwise_type>(x) << 8;
return bit_cast<DstT>(retval);
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,14 +19,18 @@ struct constant
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <auto v>
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
{
    // Widen the compile-time value to long so a single printf format
    // covers every integral/enum non-type template parameter.
    const long widened = static_cast<long>(v);
    printf("%ld", widened);
}
/// Drop-in analogue of std::integral_constant, layered on top of
/// ck_tile::constant<v> so both spellings share one implementation.
/// Exposes the usual value_type / type / value members.
template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type = integral_constant; // using injected-class-name
    static constexpr T value = v;
    // NOTE: implicit conversion / call operators are intentionally disabled;
    // presumably the base constant<v> already provides equivalents — confirm.
    // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
@@ -79,4 +83,14 @@ CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
/// Trait: detects whether T is a specialization of ck_tile::constant.
/// Primary template — anything that is not constant<v> reports false.
template <typename T>
struct is_constant : std::false_type
{
};
/// Specialization — any constant<v> instantiation reports true.
template <auto v>
struct is_constant<constant<v>> : std::true_type
{
};
/// Convenience variable template for is_constant<T>::value.
template <typename T>
inline constexpr bool is_constant_v = is_constant<T>::value;
} // namespace ck_tile

View File

@@ -31,8 +31,8 @@ struct scales
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
-> decltype(std::declval<const Scale&>() * rhs)
CK_TILE_HOST_DEVICE constexpr auto
operator()(const Right& rhs) const -> decltype(std::declval<const Scale&>() * rhs)
{
return lhs_ * rhs;
}
@@ -43,13 +43,13 @@ struct scales
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale)->scales<Scale>;
__host__ __device__ scales(Scale) -> scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs + rhs)
{
return lhs + rhs;
}
@@ -59,21 +59,21 @@ template <>
struct plus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs + rhs)
{
return lhs + rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus()->plus<void, void>;
__host__ __device__ plus() -> plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs - rhs)
{
return lhs - rhs;
}
@@ -83,21 +83,21 @@ template <>
struct minus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus()->minus<void, void>;
__host__ __device__ minus() -> minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
@@ -107,15 +107,15 @@ template <>
struct multiplies<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies()->multiplies<void, void>;
__host__ __device__ multiplies() -> multiplies<void, void>;
template <typename T>
struct maximize
@@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
template <typename Left = void, typename Right = Left>
struct equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs == rhs)
{
return lhs == rhs;
}
@@ -338,15 +338,15 @@ template <>
struct equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal()->equal<void, void>;
__host__ __device__ equal() -> equal<void, void>;
template <>
struct equal<float, float>
@@ -369,8 +369,8 @@ struct equal<double, double>
template <typename Left = void, typename Right = Left>
struct less
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs < rhs)
{
return lhs < rhs;
}
@@ -380,21 +380,21 @@ template <>
struct less<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less()->less<void, void>;
__host__ __device__ less() -> less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
@@ -404,15 +404,15 @@ template <>
struct less_equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
const Right& rhs) const -> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal()->less_equal<void, void>;
__host__ __device__ less_equal() -> less_equal<void, void>;
template <>
struct less_equal<float, float>

View File

@@ -21,7 +21,7 @@ namespace ck_tile {
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
@@ -61,8 +61,8 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1)
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
@@ -136,7 +136,7 @@ struct numeric<pk_fp4_t>
};
template <index_t I>
CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
@@ -153,7 +153,6 @@ namespace impl {
template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
// TODO: check the order
if constexpr(std::is_same_v<T, fp32_t>)
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp32x2_t>)
@@ -173,7 +172,6 @@ CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
template <typename T>
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
{
// TODO: check the order
union
{
uint32_t u32;

View File

@@ -99,7 +99,8 @@ struct numeric_traits<pk_int4_t>
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
@@ -116,6 +117,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
return res;
}
/// Unpack a packed pair of signed 4-bit integers into two floats.
/// Each nibble is first read as an unsigned value in [0, 15] and then
/// reinterpreted as two's-complement signed 4-bit in [-8, 7].
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0);
    float x_h    = ((x_u8 & 0xf0) >> 4);
    // Two's-complement sign extension: nibble values > 7 wrap to negative.
    x_l = x_l > 7 ? x_l - 16 : x_l;
    // BUG FIX: was `x_h = x_l > 7 ? x_l - 16 : x_l;` — copy-paste error that
    // overwrote the high lane with the (already converted) low lane.
    x_h = x_h > 7 ? x_h - 16 : x_h;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    // BUG FIX: was `#elif` with no condition, which is ill-formed when the
    // macro is undefined; sibling converters in this file use `#else`.
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
@@ -147,4 +166,24 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
return res;
}
/// Unpack a packed pair of signed 4-bit integers into an int8x2_t vector.
/// Each nibble is sign-extended from 4 bits to 8 bits (two's complement).
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
{
    const uint8_t packed = ck_tile::bit_cast<uint8_t>(x);
    // Split the byte into low / high nibbles.
    int8_t lo = packed & 0x0F;
    int8_t hi = (packed & 0xF0) >> 4;
    // Sign-extend: when bit 3 of a nibble is set, fill the upper four bits.
    if(lo & 0x08)
        lo |= 0xF0;
    if(hi & 0x08)
        hi |= 0xF0;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    const int8x2_t res = {hi, lo};
#else
    const int8x2_t res = {lo, hi};
#endif
    return res;
}
} // namespace ck_tile

View File

@@ -85,7 +85,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
@@ -101,7 +101,7 @@ struct vector_traits
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<
std::is_same_v<T, pk_int4_t>,
@@ -144,12 +144,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
using cktile_llvm_bf16x2_t = __bf16 __attribute__((ext_vector_type(2)));
using cktile_llvm_bf16x4_t = __bf16 __attribute__((ext_vector_type(4)));
@@ -192,58 +192,58 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));
// ui8
// using uint8_t
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
using uint8x2_t = uint8_t __attribute__((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute__((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute__((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
#endif
// pk_int4_t

View File

@@ -5,7 +5,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#if __clang_major__ == 20
#if __clang_major__ >= 20
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#else
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
@@ -62,12 +62,12 @@ struct buffer_view<address_space_enum::generic,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -265,7 +243,7 @@ struct buffer_view<address_space_enum::global,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
@@ -273,7 +251,7 @@ struct buffer_view<address_space_enum::global,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data},
@@ -767,28 +745,6 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -816,12 +772,12 @@ struct buffer_view<address_space_enum::lds,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -1004,51 +960,34 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
// clang-format off
static_assert(
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
"wrong! not implemented for this combination, please add "
"implementation");
// clang-format on
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
@@ -1100,6 +1039,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{
@@ -1163,28 +1104,6 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
    /// Debug dump of this LDS buffer_view to stdout via printf, as a single
    /// "buffer_view{...}" record: address-space tag, raw data pointer, buffer
    /// size, and the invalid-element fill value.
    /// NOTE(review): the unqualified print(member) calls resolve against this
    /// member function's name first — presumably they are meant to hit the free
    /// ck_tile::print overloads; the body only instantiates if someone calls
    /// it, so confirm this compiles when actually used.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");
        // Address space of the underlying memory: workgroup-shared LDS.
        printf("AddressSpace: Lds, ");
        // const_cast is only to produce the non-const void* that %p expects.
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
        // Extent of the view.
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");
        // Value substituted for invalid (out-of-bounds) element accesses.
        printf("invalid_element_value_: ");
        print(invalid_element_value_);
        printf("}");
    }
};
// Address Space: Vgpr
@@ -1212,12 +1131,12 @@ struct buffer_view<address_space_enum::vgpr,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -1338,35 +1257,13 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
    /// Debug dump of this VGPR buffer_view to stdout via printf, as a single
    /// "buffer_view{...}" record: address-space tag, raw data pointer, buffer
    /// size, and the invalid-element fill value.
    /// NOTE(review): the unqualified print(member) calls resolve against this
    /// member function's name first — presumably they are meant to hit the free
    /// ck_tile::print overloads; the body only instantiates if someone calls
    /// it, so confirm this compiles when actually used.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");
        // Address space of the underlying memory: per-thread vector registers.
        printf("AddressSpace: Vgpr, ");
        // const_cast is only to produce the non-const void* that %p expects.
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
        // Extent of the view.
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");
        // Value substituted for invalid (out-of-bounds) element accesses.
        printf("invalid_element_value_: ");
        print(invalid_element_value_);
        printf("}");
    }
};
template <address_space_enum BufferAddressSpace,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename T,
typename BufferSizeType>
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
}
@@ -1379,10 +1276,31 @@ template <address_space_enum BufferAddressSpace,
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
p, buffer_size, invalid_element_value};
}
// Generalized print function for all buffer_view variants (any address space
// and coherence), replacing the per-specialization member print()
// implementations. Emits a single
// "buffer_view{AddressSpace: ..., p_data_: ..., buffer_size_: ...,
// invalid_element_value_: ...}" record to stdout via printf.
template <address_space_enum BufferAddressSpace,
          typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
                                                 T,
                                                 BufferSizeType,
                                                 InvalidElementUseNumericalZeroValue,
                                                 Coherence>& bv)
{
    // address_space_to_string maps the enum to a printable tag; the const_cast
    // exists only because %p requires a non-const void*.
    printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
           address_space_to_string(BufferAddressSpace),
           static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
    print(bv.buffer_size_);
    printf(", invalid_element_value_: ");
    // NOTE(review): reads bv.invalid_element_value_ unconditionally — confirm
    // every buffer_view specialization (including the ones instantiated with
    // InvalidElementUseNumericalZeroValue == true) declares this member.
    print(bv.invalid_element_value_);
    printf("}");
}
} // namespace ck_tile

View File

@@ -17,6 +17,11 @@
namespace ck_tile {
// Fixed size constant used by the ds_read transpose (LDS transpose-read)
// path; evaluated entirely at compile time.
constexpr int DS_READ_TR_SIZE() { return 8; }
namespace util {
template <typename Suffix, typename Sequence>
struct is_sequence_suffix
@@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
template <typename DataType>
struct DefaultTranspose
{
template <index_t LaneGroupSize>
struct Quad16
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<4, 4>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
template <index_t LaneGroupSize>
struct Quad8
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<2, 8>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
// Select based on data size
template <index_t LaneGroupSize>
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::InputEncoding,
typename Quad8::InputEncoding>;
typename Quad16<LaneGroupSize>::InputEncoding,
typename Quad8<LaneGroupSize>::InputEncoding>;
template <index_t LaneGroupSize>
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::OutputEncoding,
typename Quad8::OutputEncoding>;
typename Quad16<LaneGroupSize>::OutputEncoding,
typename Quad8<LaneGroupSize>::OutputEncoding>;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};
@@ -96,51 +113,79 @@ struct DefaultTranspose
return idx; // Identity mapping
};
template <typename InDstrEncode>
struct ValidationTraits
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
struct ValidationTraitsImpl
{
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
using QuadEncoding = std::conditional_t<ReverseDirection,
QuadOutputEncoding<LaneGroupSize>,
QuadInputEncoding<LaneGroupSize>>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
// 1. Must be 2D tensor
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
// 2. Quad pattern must be suffix of input pattern
static constexpr bool suffix_valid_dim0 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
decltype(input_hs_lengthss.template get<0>())>;
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
static constexpr bool suffix_valid_dim1 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
decltype(input_hs_lengthss.template get<1>())>;
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
// 3. PS→RHS mapping constraints
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
static constexpr index_t ndimp_inner =
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
static constexpr auto input_ps_major_last =
input_ps_major[number<input_ps_major.size() - 1>{}];
static constexpr auto input_ps_minor_last =
input_ps_minor[number<input_ps_minor.size() - 1>{}];
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
input_hs[I1].size() - quad_hs[I1].size()>;
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
[](auto i) {
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
},
number<quad_ps_minor0.size()>{});
static constexpr bool ps_mapping_valid =
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
input_hs_lengthss[number<1>{}].size() - 2) &&
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
input_hs_lengthss[number<0>{}].size() - 1);
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
decltype(input_ps_minor_last)>;
// 4. YS→RHS mapping constraints
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
"YS->RHS mapping must be single dimension");
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
"YS->RHS mapping must be the last dimension");
static constexpr bool ys_mapping_valid =
(input_ys_to_rhs_major.back() == 2) &&
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
input_hs_lengthss[number<0>{}].size() - 2);
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
ps_mapping_valid && ys_mapping_valid;
};
    // Accepts an encoding if it validates for any supported lane-group size
    // (64, 32 or 16) and records which size matched.
    template <typename InDstrEncode, bool ReverseDirection = false>
    struct ValidationTraits
    {
        // True when ValidationTraitsImpl accepts the encoding for at least one
        // of the candidate lane-group sizes.
        static constexpr bool value =
            ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
            ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
            ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
        // Largest lane-group size that validates, checked 64 -> 32 -> 16;
        // 0 when none does (i.e. when value == false).
        static constexpr index_t LaneGroupSize =
            ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
            : ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
            : ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
            : 0;
    };
};
template <typename TileDistribution_, typename DataType_, typename Policy>
struct TransposeTileDistrChecker
@@ -154,111 +199,150 @@ struct TransposeTileDistrChecker
// this is used to generate the transposed output tile distribution encoding
// based on the input tile distribution encoding
template <typename TileDistribution_,
template <typename TileDistributionEncoding_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
struct OutputTileDistributionTraits
typename Policy = DefaultTranspose<DataType_>,
bool ReverseDirection = false>
struct TransposeTileDistributionTraits
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr index_t LaneGroupSize =
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
"The input tile distribution encoding is not valid for transpose!");
using QuadInputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
using QuadOutputEncoding = std::conditional_t< //
ReverseDirection,
typename Policy::template QuadInputEncoding<LaneGroupSize>,
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
static constexpr auto I0 = number<0>{};
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
static constexpr index_t dim0 = Policy::transpose_dims[0];
static constexpr index_t dim1 = Policy::transpose_dims[1];
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
// for transpose load
// append the reversed quad output hs lengths to the input hs lengthss after removing
// the quad_input_hs_lengthss
// then reverse the whole sequence to get the dst_out_hs_lengthss
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
static constexpr auto full_out_hs_lengthss = generate_tuple(
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
// dims and append the quad_output_hs_lengthss to the end of each dimension
static constexpr auto outer_hs_lengthss = generate_tuple(
[](auto i) {
return input_hs_lengthss[i]
.extract(typename arithmetic_sequence_gen<0,
input_hs_lengthss[i].size() -
quad_input_hs_lengthss[i].size(),
1>::type{})
.push_back(reversed_quad_output_hs_lengthss[i]);
constexpr auto input_i = input_hs_lengthss[i];
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
},
number<InDstrEncode::NDimX>{});
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
static constexpr auto dst_out_hs_lengthss = generate_tuple(
[](auto i) {
auto outer_i = reversed_outer_hs_lengthss[i];
// append the reversed quad output hs lengths to the outer hs lengths
return outer_i.push_back(quad_output_hs_lengthss[i]);
},
number<InDstrEncode::NDimX>{});
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
// sequence
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
// thread distr) of the major sequence
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
[](auto i) {
if constexpr(i == input_ps_to_rhss_major.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_major[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_major.push_back(number<2>{});
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_major[i];
// For all other sequences (i.e. warp), keep them unchanged
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
}
},
number<input_ps_to_rhss_major.size()>{});
static constexpr auto minor_last_index =
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
static constexpr auto quad_idx_offset =
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
// minus 1 because RsLength is not counted
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
[](auto i) {
constexpr auto input_i = input_ps_to_rhss_minor[i];
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
constexpr auto outer_ps =
typename sequence_split<decltype(input_i), outer_len>::left_type{};
return outer_ps.push_back(quad_output_ps_minor_offset +
quad_output_ps_to_rhss_minor0);
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_minor[i];
return input_i;
}
},
number<input_ps_to_rhss_minor.size()>{});
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
number<modified_ps_to_rhss_major.size()>{});
static constexpr auto dst_ys_to_rhs_major =
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
static constexpr auto modified_input_ys_to_rhs_major =
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
number<modified_input_ys_to_rhs_major.size()>{});
static constexpr auto dst_ys_to_rhs_minor =
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
using TransposedDstrEncode =
tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
};
// Forward direction: derives the transposed (output) distribution encoding
// from an input tile distribution encoding (ReverseDirection = false).
template <typename TileDistributionEncoding_,
          typename DataType_,
          typename Policy = DefaultTranspose<DataType_>>
using OutputTileDistributionTraits =
    TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;

// Reverse direction: derives the input-side encoding from an output tile
// distribution encoding; ReverseDirection = true swaps the quad input/output
// encodings inside the traits.
template <typename TileDistributionEncoding_,
          typename DataType_,
          typename Policy = DefaultTranspose<DataType_>>
using InputTileDistributionTraits =
    TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
@@ -312,18 +396,18 @@ template <
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
{
using OutTileDstrEncode =
typename OutputTileDistributionTraits<TileDistribution_,
typename BottomTensorView_::DataType>::OutDstrEncode;
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -53,10 +53,13 @@ struct is_null_tile_window<null_tile_window<T>> : public std::true_type
};
} // namespace impl
template <typename T>
constexpr bool is_null_tile_window_v = impl::is_null_tile_window<remove_cvref_t<T>>::value;
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
return is_null_tile_window_v<remove_cvref_t<T>>;
}
template <typename WindowLengths>

View File

@@ -303,6 +303,6 @@ struct tile_sweeper
template <typename T,
typename F,
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,7 +81,7 @@ struct tensor_adaptor
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE static constexpr auto
get_transform_and_its_upper_dimension(number<IDimHidden>)
get_transform_and_its_upper_dimension(number<IDimHidden>)
{
// FIXME: length of bottom dimension is not known, since info about lower dim length are not
// saved in transformation
@@ -119,13 +119,13 @@ struct tensor_adaptor
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_low_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_up_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -259,6 +259,7 @@ struct tensor_adaptor
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
@@ -266,7 +267,9 @@ struct tensor_adaptor
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
static_for<0,
Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
1>{}([&](auto itran) {
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
@@ -298,42 +301,16 @@ struct tensor_adaptor
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
if constexpr(Internal > 0)
{
return make_tuple(vector_lengths, vector_strides);
}
else
{
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
}
private:
@@ -341,6 +318,40 @@ struct tensor_adaptor
ElementSize element_size_;
};
// Free-function debug printer for tensor_adaptor: dumps the runtime transform
// chain and the compile-time hidden-dimension id mappings in a multi-line,
// bracketed "tensor_adaptor{...}" layout on stdout.
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
                                                           LowerDimensionHiddenIdss,
                                                           UpperDimensionHiddenIdss,
                                                           BottomDimensionHiddenIds,
                                                           TopDimensionHiddenIds>& adaptor)
{
    printf("tensor_adaptor{\n");
    // Runtime transform objects (e.g. pass_through, pad) held by the adaptor.
    printf("  transforms: [");
    print(adaptor.get_transforms());
    printf("],\n");
    // The id mappings below are pure compile-time types, so they are printed
    // from default-constructed instances rather than read from 'adaptor'.
    printf("  LowerDimensionHiddenIds: [");
    print(LowerDimensionHiddenIdss{});
    printf("],\n");
    printf("  UpperDimensionHiddenIds: [");
    print(UpperDimensionHiddenIdss{});
    printf("],\n");
    printf("  BottomDimensionHiddenIds: [");
    print(BottomDimensionHiddenIds{});
    printf("],\n");
    //
    printf("  TopDimensionHiddenIds: [");
    print(TopDimensionHiddenIds{});
    printf("]\n}\n");
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
@@ -461,7 +472,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
@@ -470,8 +481,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto unordered_new_top_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
@@ -595,8 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
@@ -619,8 +629,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
});
return low_dim_hidden_ids_1_mod_;
}
();
}();
return generate_sequence_v2(
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
@@ -643,8 +652,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
@@ -653,8 +661,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
});
return up_dim_hidden_ids_1_mod_;
}
();
}();
// constexpr tuple to sequence
return generate_sequence_v2(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -133,32 +133,51 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
{
return Base::get_top_dimension_safe_vector_length_strides(
return Base::template get_top_dimension_safe_vector_length_strides<Internal>(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
// Free-function debug printer for tensor_descriptor: prints the inherited
// tensor_adaptor portion first (via the base-class print overload), then the
// descriptor-specific fields: element space size and the guaranteed vector
// length/stride hints.
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths,
          typename GuaranteedVectorStrides>
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
                                                              LowerDimensionHiddenIdss,
                                                              UpperDimensionHiddenIdss,
                                                              TopDimensionHiddenIds,
                                                              ElementSpaceSize,
                                                              GuaranteedVectorLengths,
                                                              GuaranteedVectorStrides>& descriptor)
{
    printf("tensor_descriptor{\n");
    // first print the tensor adaptor part of the descriptor using the base class print
    using Base = typename tensor_descriptor<Transforms,
                                            LowerDimensionHiddenIdss,
                                            UpperDimensionHiddenIdss,
                                            TopDimensionHiddenIds,
                                            ElementSpaceSize,
                                            GuaranteedVectorLengths,
                                            GuaranteedVectorStrides>::Base;
    print(static_cast<const Base&>(descriptor));
    // static_cast<long> makes the value match the %ld conversion regardless of
    // the concrete ElementSpaceSize type.
    printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
    printf("guaranteed_vector_lengths: ");
    print(GuaranteedVectorLengths{});
    printf(",\nguaranteed_vector_strides: ");
    print(GuaranteedVectorStrides{});
    printf("}\n}\n");
}
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
@@ -365,12 +384,29 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
constexpr index_t first_dim_length = []() {
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
return decltype(element_space_size)::value;
else
return -1;
}();
using last_t = remove_cvref_t<decltype(lengths.template get<N - 1>())>;
constexpr index_t last_dim_length = []() {
if constexpr(is_constant_v<last_t>)
return std::max(last_t::value, GuaranteedLastDimensionVectorLength);
else
return -1;
}();
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
typename sequence_merge<sequence<first_dim_length>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<last_dim_length>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
typename sequence_merge<sequence<1>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,

View File

@@ -445,10 +445,11 @@ struct null_tensor_view
};
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
const tensor_descriptor<Ts...>& desc)
{
auto buffer_view =
@@ -467,7 +468,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
make_naive_tensor_view(DataType* __restrict__ p,
const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
@@ -490,7 +491,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
make_naive_tensor_view_packed(DataType* __restrict__ p,
const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{

View File

@@ -202,7 +202,7 @@ struct tile_distribution
// FIXME: it's hacky to get Y index from Distributed-Index
template <typename DistributedIndices>
CK_TILE_HOST_DEVICE static constexpr auto
get_y_indices_from_distributed_indices(DistributedIndices)
get_y_indices_from_distributed_indices(DistributedIndices)
{
constexpr auto ys_idx_arr = [] {
array<index_t, NDimY> ys_idx;
@@ -228,24 +228,6 @@ struct tile_distribution
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
@@ -266,7 +248,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t
// this returns a constexpr encoding of tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
@@ -614,8 +596,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
constexpr auto src_y_maps = src_y_info[number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::detail::ys_lengths_;
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
@@ -685,8 +666,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
}();
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
@@ -712,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
}
} // namespace detail
// Free print function for tile_distribution.
// Emits the encoding, the ps/ys->xs adaptor and the ys->d descriptor as
// "name: value" pairs inside "tile_distribution{...}" (debug aid only).
template <typename PsYs2XsAdaptor_,
          typename Ys2DDescriptor_,
          typename StaticTileDistributionEncoding_,
          typename TileDistributionDetail_>
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
                                                       Ys2DDescriptor_,
                                                       StaticTileDistributionEncoding_,
                                                       TileDistributionDetail_>& distribution)
{
    // Field labels are folded into the separator strings so each field needs
    // exactly one printf plus one recursive print of the value.
    printf("tile_distribution{tile_distribution_encoding: ");
    print(StaticTileDistributionEncoding_{});
    printf(", ps_ys_to_xs_: ");
    print(distribution.ps_ys_to_xs_);
    printf(", ys_to_d_: ");
    print(distribution.ys_to_d_);
    printf("}\n");
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -428,111 +428,29 @@ struct tile_distribution_encoding
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
template <typename encoding, typename shuffle>
class tile_distribution_encoding_shuffle;
template <typename encoding, index_t... shuffle>
class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
{
template <typename Ys2RHs>
using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
public:
using type = tile_distribution_encoding<typename encoding::RsLengths,
typename encoding::HsLengthss,
typename encoding::Ps2RHssMajor,
typename encoding::Ps2RHssMinor,
shuffled<typename encoding::Ys2RHsMajor>,
shuffled<typename encoding::Ys2RHsMinor>>;
};
template <typename encoding, typename shuffle>
using tile_distribution_encoding_shuffle_t =
typename tile_distribution_encoding_shuffle<encoding, shuffle>::type;
namespace detail {
template <typename OuterDstr, typename InnerDstr>
@@ -876,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
}
} // namespace detail
// Free print function for tile_distribution_encoding::detail.
//
// NOTE(review): every template parameter appears only inside
// `typename tile_distribution_encoding<...>::detail`, which is a non-deduced
// context, so this overload can never be selected by template argument
// deduction — callers must spell out all six template arguments explicitly.
// Consider making `detail` deducible (e.g. via a hidden-friend print) if
// deduction is desired.
template <typename RsLengths_,
          typename HsLengthss_,
          typename Ps2RHssMajor_,
          typename Ps2RHssMinor_,
          typename Ys2RHsMajor_,
          typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void
print(const typename tile_distribution_encoding<RsLengths_,
                                                HsLengthss_,
                                                Ps2RHssMajor_,
                                                Ps2RHssMinor_,
                                                Ys2RHsMajor_,
                                                Ys2RHsMinor_>::detail& detail_obj)
{
    // Emit "name: value" pairs separated by ", ". The previous version
    // printed ndim_rh_major_ twice (copy/paste slip inherited from the old
    // member print()); it is printed once here.
    printf("tile_distribution_encoding::detail{");
    printf("ndim_rh_major_: ");
    print(detail_obj.ndim_rh_major_);
    printf(", ");
    printf("ndim_span_major_: ");
    print(detail_obj.ndim_span_major_);
    printf(", ");
    printf("ndims_rhs_minor_: ");
    print(detail_obj.ndims_rhs_minor_);
    printf(", ");
    printf("max_ndim_rh_minor_: ");
    print(detail_obj.max_ndim_rh_minor_);
    printf(", ");
    printf("rhs_lengthss_: ");
    print(detail_obj.rhs_lengthss_);
    printf(", ");
    printf("ys_lengths_: ");
    print(detail_obj.ys_lengths_);
    printf(", ");
    printf("rhs_major_minor_to_ys_: ");
    print(detail_obj.rhs_major_minor_to_ys_);
    printf(", ");
    printf("ndims_span_minor_: ");
    print(detail_obj.ndims_span_minor_);
    printf(", ");
    printf("max_ndim_span_minor_: ");
    print(detail_obj.max_ndim_span_minor_);
    printf(", ");
    printf("ys_to_span_major_: ");
    print(detail_obj.ys_to_span_major_);
    printf(", ");
    printf("ys_to_span_minor_: ");
    print(detail_obj.ys_to_span_minor_);
    printf(", ");
    printf("distributed_spans_lengthss_: ");
    print(detail_obj.distributed_spans_lengthss_);
    printf(", ");
    printf("ndims_distributed_spans_minor_: ");
    print(detail_obj.ndims_distributed_spans_minor_);
    printf(", ");
    printf("ps_over_rs_derivative_: ");
    print(detail_obj.ps_over_rs_derivative_);
    printf("}");
}
// Free print function for tile_distribution_encoding.
//
// Prints the compile-time dimension counts and the R/H/P/Y mapping sequences.
// Unlike the removed member print(), this version does not dump the nested
// detail{} (its free print overload is not deducible — see the note on that
// function), so the field list ends cleanly without a trailing separator.
template <typename RsLengths_,
          typename HsLengthss_,
          typename Ps2RHssMajor_,
          typename Ps2RHssMinor_,
          typename Ys2RHsMajor_,
          typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
                                                                HsLengthss_,
                                                                Ps2RHssMajor_,
                                                                Ps2RHssMinor_,
                                                                Ys2RHsMajor_,
                                                                Ys2RHsMinor_>& encoding)
{
    printf("tile_distribution_encoding{");
    printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
    printf("rs_lengths_: ");
    print(encoding.rs_lengths_);
    printf(", ");
    printf("hs_lengthss_: ");
    print(encoding.hs_lengthss_);
    printf(", ");
    printf("ps_to_rhss_major_: ");
    print(encoding.ps_to_rhss_major_);
    printf(", ");
    printf("ps_to_rhss_minor_: ");
    print(encoding.ps_to_rhss_minor_);
    printf(", ");
    printf("ys_to_rhs_major_: ");
    print(encoding.ys_to_rhs_major_);
    printf(", ");
    printf("ys_to_rhs_minor_: ");
    print(encoding.ys_to_rhs_minor_);
    // fixed: no ", " after the last field (leftover from the dropped detail{}
    // print in the member-function version)
    printf("}");
}
} // namespace ck_tile

View File

@@ -327,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
if constexpr((std::is_same_v<DstType, fp8_t> ||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
float> &&
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0))
{
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);

View File

@@ -288,8 +288,11 @@ struct tile_window_with_static_distribution
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
// Use VALU so the compiler can optimize redundant/repeated computations
const index_t m0_init_value =
size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
m0_set_with_memory(
__builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent
using Traits = typename Base::Traits;

View File

@@ -74,8 +74,9 @@ struct tile_window_linear
static constexpr auto get_num_non_linear_access()
{
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
using ys_to_rhs_major =
typename decltype(typename Base::TileDstr{}
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear = [&]() {
index_t cnt = 1;
@@ -109,8 +110,9 @@ struct tile_window_linear
static constexpr auto get_non_linear_access_map()
{
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
using ys_to_rhs_major =
typename decltype(typename Base::TileDstr{}
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear_map = [&]() {
array<index_t, Base::Traits::NumAccess> m_{0};
index_t cumulative_len_ = 1;
@@ -244,8 +246,9 @@ struct tile_window_linear
{
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
using ys_to_rhs_major =
typename decltype(typename Base::TileDstr{}
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto modified_idx_ys = generate_tuple(
[&](auto i_dim_y) {

View File

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdio.h>
#include <tuple>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
// Compile-time "print" helper: instantiating CK_PRINT<...>() triggers a
// deprecation warning whose diagnostic spells out the template arguments,
// letting you inspect constant values during compilation. Value overload.
template <auto... val>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
// Same trick for inspecting types instead of values.
template <typename... type>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
// Compile-time string: each character is a non-type template parameter.
// `data` is NUL-terminated for direct use as a printf format; `size` counts
// the characters excluding the terminator.
template <char... Xs>
struct str_literal
{
    static constexpr const char data[] = {Xs..., '\0'};
    static constexpr const size_t size = sizeof...(Xs);
    // Concatenate two literals into a new str_literal type.
    template <char... Ys>
    CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
    {
        return str_literal<Xs..., Ys...>{};
    }
    // N copies of this literal joined by `sep`; e.g. "%d".duplicate_n<3>(" ")
    // yields "%d %d %d". N == 0 gives the empty literal.
    template <index_t N, char... Ys>
    CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
    {
        if constexpr(N == 0)
            return str_literal<>{};
        else if constexpr(N == 1)
            return str_literal<Xs...>{};
        else
            // recurse: (N-1 joined copies) + sep + this literal
            return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
    }
};
// Build a str_literal from a C string literal: expands an index sequence and
// picks (lit_)[i] for every position. Must be a macro so the literal remains
// a constant expression at each expansion site.
#define make_str_literal(lit_)                                                                     \
    std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
               makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
// Convert an index_sequence into a tuple of integral_constants so std::apply
// can hand each index to the lambda above as a distinct argument.
// NOTE(review): camelCase name is inconsistent with the file's snake_case.
template <size_t... Idx>
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
makeTuple(std::index_sequence<Idx...>) noexcept
{
    return {};
}
// Compile-time strlen(): number of characters preceding the NUL terminator.
constexpr size_t constexpr_strlen(const char* str)
{
    const char* end = str;
    while(*end != '\0')
        ++end;
    return static_cast<size_t>(end - str);
}
// Forward declarations: CK_PRINTF only needs the names here; the full
// definitions live in their own headers.
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
template <typename T_, index_t N_>
struct thread_buffer;
// Usage example: CK_PRINTF<float>{}(tensor);
// ConvertTo: convert each element to this type before printing (void keeps
// the buffer's own type). FMT/PREFIX/SUFFIX: optional str_literal pieces that
// override the default per-element format and the per-line prefix/suffix.
template <typename ConvertTo = void,
          typename FMT = str_literal<>,
          typename PREFIX = str_literal<>,
          typename SUFFIX = str_literal<>>
struct CK_PRINTF;
// Functor that prints a whole thread_buffer (or the thread buffer of a
// static_distributed_tensor) with a single printf call, using a format string
// assembled entirely at compile time.
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
struct CK_PRINTF<ConvertTo,
                 str_literal<FMTChars...>,
                 str_literal<PREFIXChars...>,
                 str_literal<SUFFIXChars...>>
{
    // Per-element printf conversion spec for T; unknown types fall back to a
    // hex dump of the raw value.
    template <typename T>
    CK_TILE_HOST_DEVICE static constexpr auto default_format()
    {
        if constexpr(std::is_same_v<T, float>)
            return make_str_literal("%8.3f");
        else if constexpr(std::is_same_v<T, int>)
            return make_str_literal("%5d");
        else if constexpr(std::is_same_v<T, unsigned int>)
            return make_str_literal("%5u");
        else
            return make_str_literal("0x%08x");
    }
    // Line prefix: "tid %03d: [%02d] " (thread id, element count), optionally
    // followed by the user-supplied PREFIX literal.
    CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
    {
        constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
        if constexpr(sizeof...(PREFIXChars) == 0)
            return fmt_tid;
        else
            return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
    }
    // Line suffix: user-supplied SUFFIX literal (if any) plus a newline.
    CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
    {
        constexpr auto lf = make_str_literal("\n");
        if constexpr(sizeof...(SUFFIXChars) == 0)
            return lf;
        else
            return str_literal<SUFFIXChars...>{} + lf;
    }
    // Build "<prefix><fmt fmt ... fmt (N times)><suffix>" at compile time and
    // emit the whole buffer with one printf; every element is converted to Y
    // first so the varargs match the chosen conversion spec.
    template <typename T, index_t N, typename Y, index_t... Is>
    CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
                                  std::integer_sequence<index_t, Is...>) const
    {
        // FMT overrides the per-element format when non-empty.
        using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
                                        decltype(default_format<Y>()),
                                        str_literal<FMTChars...>>;
        constexpr auto fmt_v      = FMT1::template duplicate_n<N>(make_str_literal(" "));
        constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
// the format is assembled from str_literal, not a string literal, so silence
// -Wformat-nonliteral for this one call
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
        printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
#pragma clang diagnostic pop
    }
    // Print a thread_buffer; ConvertTo == void means "print as stored type T".
    template <typename T, index_t N>
    CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
    {
        using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
        impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
    }
    // Print a static_distributed_tensor by forwarding its thread buffer.
    template <typename... TS>
    CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
    {
        return operator()(tensor.get_thread_buffer());
    }
};
// Same as CK_PRINTF, but only threads of the first warp emit output — handy
// for dumping one representative warp instead of the whole block.
template <typename ConvertTo = void,
          typename FMT = str_literal<>,
          typename PREFIX = str_literal<>,
          typename SUFFIX = str_literal<>>
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
{
    using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(const T& buf) const
    {
        // guard clause: threads outside the first warp stay silent
        if(get_thread_id() >= get_warp_size())
            return;
        base_t::operator()(buf);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
///
/// The primary template is intentionally uncallable: the static_assert fires
/// only when this fallback is instantiated, i.e. when no specialization or
/// overload exists for T, turning a silent mis-print into a compile error.
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
    static_assert(sizeof(T) == 0,
                  "No print implementation available for this type. Please specialize "
                  "ck_tile::print for your type.");
}
/// Specialization for int
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
    printf("%d", value);
}

/// Specialization for float (promoted to double by printf; "%f" = 6 decimals)
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
    printf("%f", value);
}

/// Specialization for double
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
    printf("%f", value);
}

/// Specialization for long
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
    printf("%ld", value);
}

/// Specialization for long long (64-bit indices/sizes on all data models;
/// added so size-like values no longer need a cast at the call site)
template <>
CK_TILE_HOST_DEVICE void print(const long long& value)
{
    printf("%lld", value);
}

/// Specialization for unsigned int
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
    printf("%u", value);
}

/// Specialization for unsigned long (covers size_t on LP64 targets)
template <>
CK_TILE_HOST_DEVICE void print(const unsigned long& value)
{
    printf("%lu", value);
}

/// Specialization for unsigned long long
template <>
CK_TILE_HOST_DEVICE void print(const unsigned long long& value)
{
    printf("%llu", value);
}

/// Specialization for char
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
    printf("%c", value);
}
/// Overload for C arrays: prints "[e0, e1, ...]", dispatching recursively to
/// the element-type print for each entry.
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
    printf("[");
    const char* separator = "";
    for(size_t i = 0; i < N; ++i)
    {
        printf("%s", separator);
        print(value[i]);
        separator = ", "; // every element after the first gets a separator
    }
    printf("]");
}
} // namespace ck_tile

View File

@@ -26,7 +26,8 @@ struct Add
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -51,13 +52,25 @@ struct SquareAdd
{
return y + (x * x);
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + (x_ * x_));
}
};
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -65,7 +78,9 @@ struct Max
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
@@ -76,7 +91,9 @@ struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::min();
@@ -84,7 +101,9 @@ struct AbsMax
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));

View File

@@ -58,8 +58,8 @@ struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
struct nonesuch
{
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};

View File

@@ -49,7 +49,7 @@ struct composes<F>
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
template <typename SaturateType>
struct saturates
@@ -57,8 +57,8 @@ struct saturates
// NOTE: this function does not return SaturateType value
// it is user's responsiblity to do further cast or not
template <typename AccType>
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
CK_TILE_HOST_DEVICE constexpr auto
operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
{
return clamp(a_,
type_convert<AccType>(numeric<SaturateType>::lowest()),