[CK_Tile] Enable PreshuffleB for 2d block scale Gemm (#3298)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and correcponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* fixing bq tensor calculation

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-12-05 09:57:52 -08:00
committed by GitHub
parent 608232ce82
commit 6b1bceca7b
7 changed files with 257 additions and 36 deletions

View File

@@ -14,6 +14,7 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
@@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
@@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t)
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view(
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}

View File

@@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
}
else
{
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);

View File

@@ -747,7 +747,6 @@ struct QuantGemmKernel
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),