refactor: use snake_case naming in ck_tile/core components (#2766)

This commit is contained in:
msaffari-amd
2025-09-03 09:34:11 +02:00
committed by GitHub
parent 4d041837ad
commit 47d020a993
11 changed files with 171 additions and 155 deletions

View File

@@ -104,7 +104,7 @@ enum struct tile_distribution_pattern
block_raked,
};
struct TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern
{
};
@@ -126,7 +126,7 @@ template <index_t BlockSize,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups = 1>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern
{
};
@@ -136,12 +136,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
@@ -165,7 +166,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
"X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(NumWaveGroups != 1)
{
@@ -189,7 +190,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
if constexpr(NumWaveGroups != 1)
{
@@ -220,12 +221,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -244,7 +246,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -255,7 +257,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<1, 1>>{}); // -> <Y1, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -273,12 +275,13 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
@@ -295,7 +298,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -306,7 +309,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
sequence<0, 1>>{}); // -> <Y0, X1>
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -336,21 +339,21 @@ template <index_t BlockSize,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
{
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
"VecSize:%d, %s>: ",
BlockSize,
YPerTile,