[rocm-libraries] ROCm/rocm-libraries#4354 (commit d41f08a)

[CK TILE] fix numerical errors of preshuffle_b

This pull request introduces several improvements and fixes related to
quantized grouped GEMM (General Matrix Multiply) pipelines and their
supporting utilities.

# The numerical issue

## Steps to reproduce
```bash
Run
./bin/tile_example_gemm_weight_preshuffle -prec=fp8
./bin/tile_example_gemm_weight_preshuffle -prec=int4
```

# Solution
The main changes address type correctness, improve data layout and
shuffling logic, and expand test coverage to better validate different
GEMM configurations.

**Key changes include:**

### Data layout and shuffling logic

* Refactored the logic in `shuffle_b_permuteN` to use `constexpr`
variables for `KLane` and `ItemsPerAccess`, simplifying tile view
construction and correcting the permutation order for improved
efficiency and correctness (`tensor_shuffle_utils.hpp`).
* Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline
policies to account for internal data type conversion (e.g., from
`pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in
quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`,
`gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`).
[[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275)
[[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105)

### Test infrastructure enhancements

* Unit tests did not catch this issue since there were no tests for fp8.
Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`)
to support additional GEMM tile shapes and updated tests to run with
these configurations for broader coverage
(`test_gemm_pipeline_util.hpp`).
[[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103)
[[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269)

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Cong Ma
2026-02-11 07:05:46 +00:00
committed by assistant-librarian[bot]
parent 807efa703a
commit d06f35027a
7 changed files with 55 additions and 42 deletions

View File

@@ -75,8 +75,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
@@ -108,8 +108,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
tail_number_v>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -227,8 +227,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQuantGroupSize,
GemmConfig::TransposeC>;
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmPipeline =
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,

View File

@@ -164,25 +164,17 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
}
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess =
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
}

View File

@@ -271,20 +271,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,

View File

@@ -131,6 +131,10 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Ty
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EDouble>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EDouble>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EDouble>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EQuad>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EQuad>; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };

View File

@@ -65,8 +65,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
// A/B DataType gets converted from PkInt4/PkFp4 during loading
using OverrideADataType = BlockGemm::OverrideADataType;
using OverrideBDataType = BlockGemm::OverrideBDataType;
using OverrideADataType = typename BlockGemm::OverrideADataType;
using OverrideBDataType = typename BlockGemm::OverrideBDataType;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;

View File

@@ -97,10 +97,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,

View File

@@ -83,10 +83,22 @@ struct config
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
};
template <typename Datatype>
struct config_mn_32x32 : public config<Datatype>
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
};
template <typename Datatype>
struct config_mn_16x16 : public config<Datatype>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
};
template <typename Datatype>
@@ -252,7 +264,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
RunSingle<config_wmma<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#else
RunSingle<config<ADataType>, PadM, PadN, PadK, Preshuffle>(
RunSingle<config_mn_16x16<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
RunSingle<config_mn_32x32<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#endif
}