From 13cf0bd17f60c0a1706c87997e2e172c01cd95fd Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 12 Nov 2025 00:23:57 +0800 Subject: [PATCH] [CK_TILE] Fix gemm_quant (#3186) [ROCm/composable_kernel commit: 1b1c46e508c1fd40a03f54114b6b78629032fb4f] --- .../38_block_scale_gemm/CMakeLists.txt | 2 +- .../38_block_scale_gemm/gemm_quant_basic.cpp | 4 + .../38_block_scale_gemm/gemm_utils.hpp | 8 ++ include/ck_tile/host/tensor_shuffle_utils.hpp | 98 ++++++++++++++----- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 1 + ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 4 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 11 +-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 6 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 5 +- .../pipeline/tile_gemm_quant_traits.hpp | 5 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 2 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 14 ++- .../test_gemm_quant_fixtures.hpp | 24 +++-- 13 files changed, 135 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..b1ae9369a2 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index b22596537f..d605a2b780 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f) int main(int argc, char* argv[]) { +#if CK_TILE_USE_WMMA + return !run_gemm_example(argc, argv); +#else // Use non-preshuffled GemmConfig for 2D block scale support return !run_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 589caf88f4..1839c7f98d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -55,21 +82,46 @@ template auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, - NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + } } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 90f6204ff3..dd2931f6b7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma static constexpr index_t kM = Impl::kM; static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; + static constexpr index_t kCMLane = Impl::kCMLane; static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 2d92745f75..6422c07e1d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 1f72f4dc12..bbdd3128bf 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmAQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3. - constexpr uint32_t kTileRowsOfCPerThread = 4; + constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8; decltype(threadIdx.x) pull_from_lane = 0; if constexpr(WarpGemm::kM == 16) { @@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // desired row coefficient auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; + ; constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; // Multiply by 4 because output is stored in tiles of 4 diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 660c30aa6e..28ae709bf0 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmBQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 36cbb87877..15d2727f3b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -240,7 +240,10 @@ struct QuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index c4429b76f9..3a5b86382d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -41,7 +41,8 @@ template + bool UsePersistentKernel_ = false, + int VectorSize_ = 16> struct TileGemmQuantTraits { static constexpr bool kPadM = kPadM_; @@ -50,7 +51,7 @@ struct TileGemmQuantTraits static constexpr QuantType kQuantType = QuantType_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using ALayout = ALayout_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 3a49e69c37..1c4a25c8bd 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 6454101daf..6226a2de9e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - + // WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize. + // static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + // VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only + // need to check the limitation on RDNA cards, i.e. assume wave size is 32. + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test ALayout, BLayout, GemmConfig::TransposeC, - DoubleSmemBuffer>; + DoubleSmemBuffer, + false, + VectorSize>; // Let the derived class create the appropriate pipeline and epilogue static_cast(this) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index cabc0ec02c..5aac095514 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -7,6 +7,16 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else + return is_8bit ? 64 : 32; +#endif +} + struct GemmConfigBase { static constexpr bool kPadM = false; @@ -40,7 +50,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleQuant : public GemmConfigBase @@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefill : public GemmConfigBase @@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill @@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase