mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)
* [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * feat(gemm_quant): add RRR and CRR layout support for aquant gemm * test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm * fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests * fix: test cases compilation failure due to PR# 2095 * fix: make condition to filter out tests for gfx950 more explicit * need to support the gfx950 * fix: add layout support for gfx950 * Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277) * WIP: add support for pk_int4_t for aquant mode layouts RR and CR * test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant * fix: compile time error for gfx950 * fix: minor bug where is_a_load_tr_v() was missing * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * docs: update changelog with new layout support info * Update CHANGELOG.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: break test instances into multiple cpp files to reduce build time (#3319) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * refactor: break test instances into multiple cpp files to reduce build time * chore: rename file for better code readability * fix: merge conflict resolution * fix: remove memory pipeline because new layout is not compatible * build: resolve build errors for gfx950 by modifying is_a_load_tr() & 
is_b_load_tr() * refactor: address review comments * solve the conflict --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// The combination of pk_int4_t and transposed loading causes compilation errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
// Compile-time flag for the A operand: evaluates to true only when ALayout is
// ColumnMajor and neither of the two guards below rejects the configuration.
// Written as an immediately-invoked lambda so the decision can be expressed as
// an if-constexpr chain.
// NOTE(review): the pk_int4_t guard inspects BDataType, not ADataType, even
// though this flag concerns A. is_b_load_tr uses the same guard, so this looks
// deliberate -- confirm.
static constexpr bool is_a_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
// Entry 2 of the warp tile; presumably the K extent, as the name suggests.
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Upper bound for kKWarpTile: 64 when ADataType is one byte wide, 32 for any other size.
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
// Guard 1: never enable the flag when B is pk_int4_t.
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
// Guard 2: never enable the flag when the warp tile exceeds the size-dependent bound.
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
// Otherwise the flag is decided purely by A's layout.
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
// Compile-time flag for the B operand: evaluates to true only when BLayout is
// RowMajor and neither of the two guards below rejects the configuration.
// Written as an immediately-invoked lambda so the decision can be expressed as
// an if-constexpr chain.
static constexpr bool is_b_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
// Entry 2 of the warp tile; presumably the K extent, as the name suggests.
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Upper bound for kKWarpTile: 64 when BDataType is one byte wide, 32 for any other size.
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
// Guard 1: never enable the flag when B is pk_int4_t.
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
// Guard 2: never enable the flag when the warp tile exceeds the size-dependent bound.
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
// Otherwise the flag is decided purely by B's layout.
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
@@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
|
||||
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
|
||||
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
|
||||
{
|
||||
// A tile in LDS
|
||||
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc =
|
||||
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
|
||||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
|
||||
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
@@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
@@ -37,11 +37,22 @@ struct UniversalGemmBasePolicy
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
@@ -49,9 +60,15 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
@@ -87,13 +104,12 @@ struct UniversalGemmBasePolicy
|
||||
return DefaultBTileAccessPattern;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = OverrideADataType;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user