refactor: use snake_case naming in ck_tile/core components (#2766)

This commit is contained in:
msaffari-amd
2025-09-03 09:34:11 +02:00
committed by GitHub
parent 4d041837ad
commit 47d020a993
11 changed files with 171 additions and 155 deletions

View File

@@ -168,11 +168,11 @@ struct UniversalGemmBasePolicy
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
@@ -494,24 +494,24 @@ struct UniversalGemmBasePolicy
// Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::Make2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: KPerBlock X MPerBlock
else
{
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::Make2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
@@ -530,24 +530,24 @@ struct UniversalGemmBasePolicy
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::Make2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::Make2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
@@ -562,13 +562,13 @@ struct UniversalGemmBasePolicy
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}
template <typename Problem>
@@ -582,13 +582,13 @@ struct UniversalGemmBasePolicy
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}
template <typename Problem>