Make CK TILE GEMM Aquant support block tile 128x128x128 (#3325)

* [CK TILE GEMM Quant] Rename GemmConfigBQuantPrefill to GemmConfigQuantPrefill in examples

* [CK TILE GEMM Quant] update tile distribution of aquant

* [CK TILE GEMM Quant] update aquant register offset calculation

* [CK TILE GEMM Quant] Reimplement aquant register offset calculation

* [CK TILE GEMM Quant] Add more unit tests of Aquant

- Test M128xN128xK128

* [CK TILE GEMM Quant] Add more comments to Gemm Aquant
This commit is contained in:
Cong Ma
2025-12-01 16:04:37 -07:00
committed by GitHub
parent 7873f8fa13
commit 23fb253c4e
11 changed files with 58 additions and 46 deletions

View File

@@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode:
|:--------|:-----:|:-----:|-------|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |

View File

@@ -6,6 +6,10 @@
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
// template <typename T>
// using GemmConfig = GemmConfigQuantPrefill<T>;
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
};
template <typename PrecType>
struct GemmConfigBQuantPrefill : public GemmConfigBase
struct GemmConfigQuantPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;