refactor: use snake_case naming in ck_tile/core components (#2766)

This commit is contained in:
msaffari-amd
2025-09-03 09:34:11 +02:00
committed by GitHub
parent 4d041837ad
commit 47d020a993
11 changed files with 171 additions and 155 deletions

View File

@@ -55,44 +55,43 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock / WarpGemm::kM,
ck_tile::integer_least_multiple(
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock / WarpGemm::kM,
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
if constexpr(Problem::TransposeC)
{
using TileEncodingPatternTransposeC =
TileDistributionEncodingPatternAQTransposedC<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
return TileEncodingPatternTransposeC::Make2DStaticTileDistribution();
tile_distribution_encoding_pattern_aq_transposed_c<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
return TileEncodingPatternTransposeC::make_2d_static_tile_distribution();
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
}

View File

@@ -330,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
@@ -342,7 +342,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}

View File

@@ -52,14 +52,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlock,
KPerBlockBQ,
VecLoadSize>;
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlock,
KPerBlockBQ,
VecLoadSize>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::make_2d_static_tile_distribution();
}
template <typename Problem>

View File

@@ -326,7 +326,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
@@ -338,7 +338,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}

View File

@@ -53,7 +53,7 @@ template <typename BlockGemmShape,
index_t KPerBlockAQ,
index_t VecSize,
bool PreshuffleQuant>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
@@ -70,7 +70,7 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(PreshuffleQuant)
{
@@ -119,7 +119,8 @@ template <typename BlockGemmShape,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_aq_transposed_c
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -152,7 +153,7 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps, XR>,
@@ -171,7 +172,7 @@ template <typename BlockGemmShape,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -204,7 +205,7 @@ struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPatter
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR>,