feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp files to reduce build time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* feat(gemm_quant): add RRR and CRR layout support for aquant gemm

* test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm

* fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests

* fix: test case compilation failure due to PR #2095

* fix: make condition to filter out tests for gfx950 more explicit

* need to support the gfx950

* fix: add layout support for gfx950

* Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277)

* WIP: add support for pk_int4_t for aquant mode layouts RR and CR

* test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant

* fix: compile time error for gfx950

* fix: minor bug where is_a_load_tr_v() was missing

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* docs: update changelog with new layout support info

* Update CHANGELOG.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: break test instances into multiple cpp files to reduce build time (#3319)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* refactor: break test instances into multiple cpp files to reduce build time

* chore: rename file for better code readability

* fix: merge conflict resolution

* fix: remove memory pipeline because new layout is not compatible

* build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr()

* refactor: address review comments

* solve the conflict

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-12-03 02:59:07 +04:00
committed by GitHub
parent 2c284a1780
commit 6cb0bc2d11
22 changed files with 603 additions and 289 deletions

View File

@@ -21,7 +21,9 @@
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
@@ -51,8 +53,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
BLayout,
CLayout,
QuantMode,
ALayout, // for AQLayout
BLayout, // for BQLayout
AQLayout, // for AQLayout
BQLayout, // for BQLayout
false,
GemmConfig::DoubleSmemBuffer>;
@@ -67,12 +69,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -131,9 +128,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
std::conditional_t<GemmConfig::PreshuffleQuant == true,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
@@ -289,7 +284,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time = gemm_calc_quant<GemmConfig,
TypeConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
QuantGroupSize,
QuantMode,
@@ -317,7 +314,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name;
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " AQ_Layout =" << AQLayout::name << " BQ_Layout =" << BQLayout::name;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
@@ -792,6 +790,39 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
{
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
QuantMode>(
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
QuantMode>(
arg_parser, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");