[rocm-libraries] ROCm/rocm-libraries#4354 (commit d41f08a)

[CK TILE] fix numerical errors of preshuffle_b

This pull request introduces several improvements and fixes related to
quantized grouped GEMM (General Matrix Multiply) pipelines and their
supporting utilities.

# The numerical issue

## Steps to reproduce
```bash
Run
./bin/tile_example_gemm_weight_preshuffle -prec=fp8
./bin/tile_example_gemm_weight_preshuffle -prec=int4
```

# Solution
The main changes address type correctness, improve data layout and
shuffling logic, and expand test coverage to better validate different
GEMM configurations.

**Key changes include:**

### Data layout and shuffling logic

* Refactored the logic in `shuffle_b_permuteN` to use `constexpr`
variables for `KLane` and `ItemsPerAccess`, simplifying tile view
construction and correcting the permutation order for improved
efficiency and correctness (`tensor_shuffle_utils.hpp`).
* Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline
policies to account for internal data type conversion (e.g., from
`pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in
quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`,
`gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`).
[[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275)
[[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105)

### Test infrastructure enhancements

* Unit tests did not catch this issue since there were no tests for fp8.
Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`)
to support additional GEMM tile shapes and updated tests to run with
these configurations for broader coverage
(`test_gemm_pipeline_util.hpp`).
[[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103)
[[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269)

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Cong Ma
2026-02-11 07:05:46 +00:00
committed by assistant-librarian[bot]
parent 807efa703a
commit d06f35027a
7 changed files with 55 additions and 42 deletions

View File

@@ -75,8 +75,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
@@ -108,8 +108,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
tail_number_v>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -227,8 +227,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQuantGroupSize,
GemmConfig::TransposeC>;
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmPipeline =
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,