mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
refactor: use snake_case naming in ck_tile/core components (#2766)
This commit is contained in:
@@ -55,44 +55,43 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(
|
||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
|
||||
BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Problem::TransposeC)
|
||||
{
|
||||
using TileEncodingPatternTransposeC =
|
||||
TileDistributionEncodingPatternAQTransposedC<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
return TileEncodingPatternTransposeC::Make2DStaticTileDistribution();
|
||||
tile_distribution_encoding_pattern_aq_transposed_c<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
return TileEncodingPatternTransposeC::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -342,7 +342,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
|
||||
@@ -52,14 +52,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
VecLoadSize>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
VecLoadSize>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -326,7 +326,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -338,7 +338,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ template <typename BlockGemmShape,
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool PreshuffleQuant>
|
||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
@@ -70,7 +70,7 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
@@ -119,7 +119,8 @@ template <typename BlockGemmShape,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_aq_transposed_c
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -152,7 +153,7 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps, XR>,
|
||||
@@ -171,7 +172,7 @@ template <typename BlockGemmShape,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -204,7 +205,7 @@ struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPatter
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR>,
|
||||
|
||||
Reference in New Issue
Block a user