feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* feat(gemm_quant): add RRR and CRR layout support for aquant gemm

* test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm

* fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests

* fix: test cases compilation failure due to PR# 2095

* fix: make condition to filter out tests for gfx950 more explicit

* need to support the gfx950

* fix: add layout suppot for gfx950

* Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277)

* WIP: add support for pk_int4_t for aquant mode layouts RR and CR

* test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant

* fix: compile time error for gfx950

* fix: minor bug where is_a_load_tr_v() was mising

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* docs: update changelog with new layout support info

* Update CHANGELOG.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: break test instances into multiple cpp files to reduce build time (#3319)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* refactor: break test instances into multiple cpp files to reduce build time

* chore: rename file for better code readability

* fix: merge conflict resolution

* fix: remove memory pipeline because new layout is not compatible

* build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr()

* refactor: address review comments

* solve the conflict

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-12-03 02:59:07 +04:00
committed by GitHub
parent 2c284a1780
commit 6cb0bc2d11
22 changed files with 603 additions and 289 deletions

View File

@@ -435,12 +435,22 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}
// C += A * B
@@ -522,11 +532,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
MakeCBlockTile();
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B