[CK_TILE] ABQuant New Preshuffle (#3638)

* Refactor

* Gemm quant improvement

* Change preshuffle

* Fix

* Fix grouped gemm ut

* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Yi DING
2026-01-28 15:46:49 +08:00
committed by GitHub
parent 91e32f305f
commit 8e3d84aba3
32 changed files with 182 additions and 213 deletions

View File

@@ -1137,7 +1137,7 @@ CK_TILE_DEVICE static constexpr auto get_device_arch()
#endif
}
CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
CK_TILE_DEVICE static constexpr auto get_n_dwords_per_128b() { return 4; }
namespace detail {
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }

View File

@@ -69,7 +69,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
auto shuffle_b(const ck_tile::HostTensor<T>& t, GemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -79,36 +79,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
else if(ck_tile::is_gfx11_supported())
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else
{
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess =
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}
template <typename GemmConfig, typename T>

View File

@@ -160,7 +160,7 @@ struct UniversalGemmBasePolicy
constexpr auto K0PerThreadRead = AK0 / KThreadRead;
// check if we exceed all LDS banks
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
@@ -250,7 +250,7 @@ struct UniversalGemmBasePolicy
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
@@ -357,7 +357,7 @@ struct UniversalGemmBasePolicy
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
// check if we exceed all LDS banks
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
@@ -450,7 +450,7 @@ struct UniversalGemmBasePolicy
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto NLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");

View File

@@ -151,6 +151,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using BDataType = typename Problem::BDataType;
constexpr index_t kNPerBlock = TileShape::kN;
constexpr index_t kKPerBlock = TileShape::kK;
@@ -162,16 +163,18 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KBPerLoad = min(
GetKBPerLoad<Problem>(), KRepeatInWave * 16 / static_cast<index_t>(sizeof(BDataType)));
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = KIterPerWarp;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t KAccess = GetKBPerLoad<Problem>() / KBPerLoad;
static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -181,16 +184,16 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KAccess, KWavePerBlk, KThdPerWave, KBPerLoad>>,
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<1, 2, 3>>, // which index
// <repeat, vec_load>
sequence<1, 2, 1, 2>,
sequence<0, 0, 3, 3>>{});
sequence<1, 2, 1, 2, 2>,
sequence<0, 0, 3, 1, 4>>{});
}
template <typename Problem>
@@ -256,13 +259,22 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,

View File

@@ -693,13 +693,13 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleB)
{
index_t kFlatK =
GemmPipeline::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
constexpr auto warp_k = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
index_t kFlatKSplit = GemmPipeline::flatKPerWarp * (k_size / warp_k);
index_t kFlatK = GemmPipeline::flatKPerWarp * (kargs.K / warp_k);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatN, kFlatKSplit),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});

View File

@@ -52,11 +52,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using BDataType = typename Problem::BDataType;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KBPerLoad =
min(GetKBPerLoad<Problem>(), 16 / static_cast<index_t>(sizeof(BDataType)));
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
@@ -64,8 +66,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t KRepeat = GetKBPerLoad<Problem>() / KBPerLoad;
static_assert(TileShape::flatKPerWarp == KRepeat * KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -98,13 +100,23 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
Problem::TransposeC,
false,
false,
NumAccess>;
// TODO : Use a custom block policy for AsBrCr
using BlockGemmPolicy =