mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
refactor: use snake_case naming in ck_tile/core components (#2766)
This commit is contained in:
@@ -104,7 +104,7 @@ enum struct tile_distribution_pattern
|
||||
block_raked,
|
||||
};
|
||||
|
||||
struct TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
@@ -126,7 +126,7 @@ template <index_t BlockSize,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern
|
||||
{
|
||||
};
|
||||
|
||||
@@ -136,12 +136,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -165,7 +166,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
@@ -189,7 +190,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
@@ -220,12 +221,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -244,7 +246,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -255,7 +257,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -273,12 +275,13 @@ template <index_t BlockSize,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -295,7 +298,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -306,7 +309,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -336,21 +339,21 @@ template <index_t BlockSize,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
using PatternType = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
printf("tile_distribution_encoding_pattern_2d<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
|
||||
@@ -21,12 +21,12 @@ struct BatchedTransposeCommonPolicy
|
||||
|
||||
constexpr index_t kVectorSize = Problem::VectorSizeInput;
|
||||
static_assert((kLeadDimPerBlock * kVectorSize) % kBlockSize == 0, "");
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kSecondDimPerBlock,
|
||||
kLeadDimPerBlock,
|
||||
kVectorSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
kSecondDimPerBlock,
|
||||
kLeadDimPerBlock,
|
||||
kVectorSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -291,13 +291,14 @@ struct CShuffleEpilogue
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
|
||||
@@ -168,11 +168,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
|
||||
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
|
||||
@@ -494,24 +494,24 @@ struct UniversalGemmBasePolicy
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: KPerBlock X MPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,24 +530,24 @@ struct UniversalGemmBasePolicy
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,13 +562,13 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -582,13 +582,13 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -55,44 +55,43 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(
|
||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
|
||||
BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Problem::TransposeC)
|
||||
{
|
||||
using TileEncodingPatternTransposeC =
|
||||
TileDistributionEncodingPatternAQTransposedC<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
return TileEncodingPatternTransposeC::Make2DStaticTileDistribution();
|
||||
tile_distribution_encoding_pattern_aq_transposed_c<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
return TileEncodingPatternTransposeC::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -342,7 +342,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
|
||||
@@ -52,14 +52,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
VecLoadSize>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
VecLoadSize>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -326,7 +326,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -338,7 +338,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ template <typename BlockGemmShape,
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool PreshuffleQuant>
|
||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
@@ -70,7 +70,7 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
@@ -119,7 +119,8 @@ template <typename BlockGemmShape,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_aq_transposed_c
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -152,7 +153,7 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps, XR>,
|
||||
@@ -171,7 +172,7 @@ template <typename BlockGemmShape,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -204,7 +205,7 @@ struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPatter
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR>,
|
||||
|
||||
@@ -32,13 +32,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern)
|
||||
{
|
||||
// Test printing thread raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>;
|
||||
tile_distribution_encoding_pattern_2d<64,
|
||||
8,
|
||||
16,
|
||||
4,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos);
|
||||
@@ -52,13 +56,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern)
|
||||
{
|
||||
// Test printing warp raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>;
|
||||
tile_distribution_encoding_pattern_2d<128,
|
||||
16,
|
||||
32,
|
||||
8,
|
||||
tile_distribution_pattern::warp_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos);
|
||||
@@ -72,13 +80,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern)
|
||||
{
|
||||
// Test printing block raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>;
|
||||
tile_distribution_encoding_pattern_2d<256,
|
||||
32,
|
||||
64,
|
||||
16,
|
||||
tile_distribution_pattern::block_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);
|
||||
|
||||
Reference in New Issue
Block a user