mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] ABQuant New Preshuffle (#3638)
* Refactor
* Gemm quant improvement
* Change preshuffle
* Fix
* Fix grouped gemm ut
* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -693,13 +693,13 @@ struct QuantGemmKernel
|
||||
{
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
constexpr auto warp_k = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
index_t kFlatKSplit = GemmPipeline::flatKPerWarp * (k_size / warp_k);
|
||||
index_t kFlatK = GemmPipeline::flatKPerWarp * (kargs.K / warp_k);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatN, kFlatKSplit),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -52,11 +52,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KBPerLoad =
|
||||
min(GetKBPerLoad<Problem>(), 16 / static_cast<index_t>(sizeof(BDataType)));
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
@@ -64,8 +66,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
constexpr index_t KRepeat = GetKBPerLoad<Problem>() / KBPerLoad;
|
||||
static_assert(TileShape::flatKPerWarp == KRepeat * KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
@@ -98,13 +100,23 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
constexpr index_t KLaneBytes =
|
||||
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
NumAccess>;
|
||||
|
||||
// TODO : Use a custom block policy for AsBrCr
|
||||
using BlockGemmPolicy =
|
||||
|
||||
Reference in New Issue
Block a user