mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] ABQuant New Preshuffle (#3638)
* Refactor * Gemm quant improvement * Change preshuffle * Fix * Fix grouped gemm ut * Fix --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -160,7 +160,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto K0PerThreadRead = AK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
|
||||
@@ -250,7 +250,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
@@ -357,7 +357,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
|
||||
@@ -450,7 +450,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto NLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
|
||||
@@ -151,6 +151,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
|
||||
constexpr index_t kNPerBlock = TileShape::kN;
|
||||
constexpr index_t kKPerBlock = TileShape::kK;
|
||||
@@ -162,16 +163,18 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KBPerLoad = min(
|
||||
GetKBPerLoad<Problem>(), KRepeatInWave * 16 / static_cast<index_t>(sizeof(BDataType)));
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = KIterPerWarp;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
constexpr index_t KAccess = GetKBPerLoad<Problem>() / KBPerLoad;
|
||||
static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
@@ -181,16 +184,16 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KAccess, KWavePerBlk, KThdPerWave, KBPerLoad>>,
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2, 3>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 2, 1, 2>,
|
||||
sequence<0, 0, 3, 3>>{});
|
||||
sequence<1, 2, 1, 2, 2>,
|
||||
sequence<0, 0, 3, 1, 4>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -256,13 +259,22 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
constexpr index_t KLaneBytes =
|
||||
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
NumAccess>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
|
||||
Reference in New Issue
Block a user