[CK_Tile] Enable PreshuffleB for 2d block scale Gemm (#3298)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and correcponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* fixing bq tensor calculation

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-12-05 09:57:52 -08:00
committed by GitHub
parent 608232ce82
commit 6b1bceca7b
7 changed files with 257 additions and 36 deletions

View File

@@ -14,6 +14,7 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
@@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
@@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t)
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view(
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}