feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* feat(gemm_quant): add RRR and CRR layout support for aquant gemm

* test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm

* fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests

* fix: test cases compilation failure due to PR# 2095

* fix: make condition to filter out tests for gfx950 more explicit

* need to support the gfx950

* fix: add layout suppot for gfx950

* Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277)

* WIP: add support for pk_int4_t for aquant mode layouts RR and CR

* test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant

* fix: compile time error for gfx950

* fix: minor bug where is_a_load_tr_v() was mising

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* docs: update changelog with new layout support info

* Update CHANGELOG.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: break test instances into multiple cpp files to reduce build time (#3319)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* refactor: break test instances into multiple cpp files to reduce build time

* chore: rename file for better code readability

* fix: merge conflict resolution

* fix: remove memory pipeline because new layout is not compatible

* build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr()

* refactor: address review comments

* solve the conflict

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-12-03 02:59:07 +04:00
committed by GitHub
parent 2c284a1780
commit 6cb0bc2d11
22 changed files with 603 additions and 289 deletions

View File

@@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// The combination of pk_int4_t and transposed loading causes compilation errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
static constexpr bool is_a_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();
static constexpr bool is_b_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
@@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase
load_tile(dst_block_tile, lds_tile_window);
}
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
constexpr auto a_lds_block_desc =
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

View File

@@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
template <typename Problem>
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;

View File

@@ -37,11 +37,22 @@ struct UniversalGemmBasePolicy
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
template <typename Problem>
static constexpr bool is_a_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
@@ -49,9 +60,15 @@ struct UniversalGemmBasePolicy
template <typename Problem>
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
@@ -87,13 +104,12 @@ struct UniversalGemmBasePolicy
return DefaultBTileAccessPattern;
}
template <typename Problem>
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();