mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
refactor: use snake_case naming in ck_tile/core components (#2766)
This commit is contained in:
@@ -168,11 +168,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
|
||||
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
|
||||
@@ -494,24 +494,24 @@ struct UniversalGemmBasePolicy
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: KPerBlock X MPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,24 +530,24 @@ struct UniversalGemmBasePolicy
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,13 +562,13 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -582,13 +582,13 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user