mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Add Memory pipeline for AQuant Block Scale GEMM (#2987)
* WIP: add memory pipeline boilerplate code that compiles and works for one block * WIP: tail handling works for memory pipeline * WIP: numerical errors appear to have gone away after adding block_sync_lds() * fix: numerical error with memory pipeline by adding block_sync_lds() and new tail handler * refactor: remove debug print statements and lints * fix: remove redundant sync barriers * chore: remove lint * fix: remove unused code from tile handler and remove redundant block_sync_lds() * fix: correct parent struct name for memory pipeline * fix: remove static assert check from parent struct and add it to child struct because not all child structs need the static assert * fix: defer block sync lds to just before prefill
This commit is contained in:
11
example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Executable file → Normal file
11
example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Executable file → Normal file
@@ -59,7 +59,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
|
||||
// for aquant
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -118,7 +119,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>, // memory pipeline hardcoded
|
||||
// for aquant
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
@@ -448,7 +450,4 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_decode>(argc, argv);
|
||||
}
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
|
||||
|
||||
4
example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Executable file → Normal file
4
example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Executable file → Normal file
@@ -182,7 +182,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
ck_tile::index_t AQK, BQK;
|
||||
@@ -204,7 +204,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported QuantMode");
|
||||
throw std::runtime_error("Unsupported QuantMode");
|
||||
}
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
|
||||
Reference in New Issue
Block a user