Add Memory pipeline for AQuant Block Scale GEMM (#2987)

* WIP: add memory pipeline boiler plate code that compiles and works for one block

* WIP: tail handling works for memory pipeline

* WIP: numerical errors appears to have gone by adding block_sync_lds()

* fix: numerical error with memory pipeline by adding block_sync_lds() and new tail handler

* refactror: remove debug print statements and lints

* fix: remove redundant sync barriars

* chore: remove lint

* fix: remove unused code from tile handler and remove redundant block_sync_lds()

* fix: correct parent struct name for memory pipeline

* fix: remove static assert check from parent struct and add it to child struct because not all child structs needs to static assert

* fix: defer block sync lds to just before prefill
This commit is contained in:
Aviral Goel
2025-10-08 20:22:30 -04:00
committed by GitHub
parent e29151b533
commit e99356dabc
6 changed files with 489 additions and 12 deletions

View File

@@ -59,7 +59,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
// for aquant
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -118,7 +119,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>, // memory pipeline hardcoded
// for aquant
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
@@ -448,7 +450,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_decode>(argc, argv);
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }

View File

@@ -182,7 +182,7 @@ int run_gemm_example_with_layouts(int argc,
if(K % QuantGroupSize != 0)
{
throw std::runtime_error(
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
}
}
ck_tile::index_t AQK, BQK;
@@ -204,7 +204,7 @@ int run_gemm_example_with_layouts(int argc,
}
else
{
static_assert(false, "Unsupported QuantMode");
throw std::runtime_error("Unsupported QuantMode");
}
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");