From bdbab2394bc1dbf0407597eac10aee867844b73e Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 2 Oct 2025 18:14:39 +0000 Subject: [PATCH] Merge commit '6fc28ab4934d3668bf4ec96db1e082cf26b11384' into develop --- .../block_universal_gemm_as_aquant_bs_cr.hpp | 41 ++++++++++++--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 2 +- .../pipeline/tile_gemm_quant_traits.hpp | 3 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 4 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 1 + .../test_gemm_quant_fixtures.hpp | 52 +++++++++++++++++-- .../test_gemm_quant_typed.cpp | 21 +++++++- 7 files changed, 109 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index d4bece1a83..cb20bdbd50 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { if constexpr(Traits::TransposeC) // transposed C { - static_assert(false, - "It is not supported yet to enable both Preshuffle " - "and TransposeC."); - // TODO: - // A new tile distribution is needed for the Preshuffle and - // Transpose combination. For instance, with mnk at 16x16x32, lanes - // 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ. + constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; + auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) * + Traits::AQPerBlock + + kQScale; + + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * + scale_reg_f); + }); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index a0b6fc5821..bba2bc8400 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled }; template -struct is_quantpreshuffle_enabled +struct is_quantpreshuffle_enabled> { static constexpr bool value = T::PreshuffleQuant; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index 52a326a897..c4429b76f9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -39,6 +39,7 @@ template struct TileGemmQuantTraits @@ -62,7 +63,7 @@ struct TileGemmQuantTraits using AsLayout = ALayout_; using BsLayout = BLayout_; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; static constexpr bool UsePersistentKernel = UsePersistentKernel_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 93a13ba5af..3a49e69c37 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -7,7 +7,9 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") # Typed Test Suite for GEMM Quantization - add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp) + add_gtest_executable(test_tile_gemm_quant_typed + test_gemm_quant_typed.cpp + ) target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 355e9fce32..80167a1d21 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -87,6 +87,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test QuantType, ALayout, BLayout, + GemmConfig::TransposeC, DoubleSmemBuffer>; // Let the derived class create the appropriate pipeline and epilogue diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 98f88f4d53..21eabd6041 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -41,6 +41,22 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 32; }; +struct GemmConfigPreshuffleQuant : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; +}; + +struct GemmConfigTransposeC : public GemmConfigBase +{ + static constexpr bool TransposeC = true; +}; + +struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; + static constexpr bool TransposeC = true; +}; + struct GemmConfigPreshuffleB { static constexpr bool kPadM = false; @@ -100,6 +116,24 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase + auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) + { + if(t->get_lengths().size() != 2) + { + throw std::runtime_error("Host tensor is not rank 2 tensor."); + } + int m_ = t->get_lengths()[0]; + int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) + { + throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); + } + ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); + } + // AQuant-specific data generation void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) { @@ -150,7 +184,17 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase aq_shuffle_host = + shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } b_k_n_dev_buf.ToDevice(b_k_n.data()); // Create args for kernel execution @@ -245,7 +289,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase; // Type combinations for each quantization type // clang-format off using AQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // PreshuffleQuant = false && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = false + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on