mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
refactor: use snake_case naming in ck_tile/core components (#2766)
This commit is contained in:
@@ -104,7 +104,7 @@ enum struct tile_distribution_pattern
|
||||
block_raked,
|
||||
};
|
||||
|
||||
struct TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
@@ -126,7 +126,7 @@ template <index_t BlockSize,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
@@ -136,12 +136,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -165,7 +166,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
@@ -189,7 +190,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
@@ -220,12 +221,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -244,7 +246,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -255,7 +257,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -273,12 +275,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -295,7 +298,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -306,7 +309,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -336,21 +339,21 @@ template <index_t BlockSize,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
|
||||
Reference in New Issue
Block a user