mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4354 (commit d41f08a)
[CK TILE] fix numerical errors of preshuffle_b This pull request introduces several improvements and fixes related to quantized grouped GEMM (General Matrix Multiply) pipelines and their supporting utilities. # The numerical issue ## Steps to reproduce ```bash Run ./bin/tile_example_gemm_weight_preshuffle -prec=fp8 ./bin/tile_example_gemm_weight_preshuffle -prec=int4 ``` # Solution The main changes address type correctness, improve data layout and shuffling logic, and expand test coverage to better validate different GEMM configurations. **Key changes include:** ### Data layout and shuffling logic * Refactored the logic in `shuffle_b_permuteN` to use `constexpr` variables for `KLane` and `ItemsPerAccess`, simplifying tile view construction and correcting the permutation order for improved efficiency and correctness (`tensor_shuffle_utils.hpp`). * Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline policies to account for internal data type conversion (e.g., from `pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`, `gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`). [[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275) [[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105) ### Test infrastructure enhancements * Unit tests did not catch this issue since there were no tests for fp8. Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`) to support additional GEMM tile shapes and updated tests to run with these configurations for broader coverage (`test_gemm_pipeline_util.hpp`). [[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103) [[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
807efa703a
commit
d06f35027a
@@ -83,10 +83,22 @@ struct config
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
struct config_mn_32x32 : public config<Datatype>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
struct config_mn_16x16 : public config<Datatype>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
@@ -252,7 +264,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
RunSingle<config_wmma<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
#else
|
||||
RunSingle<config<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
RunSingle<config_mn_16x16<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
RunSingle<config_mn_32x32<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user