Preshuffle AQ matrix in block scale gemm (#2624)

* Preshuffle AQ matrix in block scale gemm

* Change the output type to fp16 and increase the number of repetitions.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-12 22:32:51 -06:00
committed by GitHub
parent 0f42a92fc1
commit 452791a3ba
13 changed files with 667 additions and 228 deletions

View File

@@ -24,7 +24,8 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
uint32_t QuantGroupSize,
bool Preshuffle = false>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -55,7 +56,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -161,7 +162,8 @@ template <typename ADataType,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
uint32_t QuantGroupSize,
bool Preshuffle = false>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -202,7 +204,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
QuantGroupSize,
Preshuffle>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;