diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index b81c5de7ab..3a30c2bad3 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode: |:--------|:-----:|:-----:|-------| | For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode | | For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode | -| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill | +| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill | | For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill | | For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill | For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index 0f75976602..ad1a4e0d10 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -6,6 +6,10 @@ template using GemmConfig = GemmConfigQuantDecode; +// GemmConfigQuantPrefill is also supported for aquant grouped quantization +// template +// using GemmConfig = GemmConfigQuantPrefill; + void aquant_quantgrouped_instance_factory( std::unordered_map>& lut) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 2dbae9e42c..61fd65960f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 40cf88624b..1d471068eb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 5c21d5aa16..280029033b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 80b9a2765e..a277c864bb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 81032d6452..116661c157 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill }; template -struct GemmConfigBQuantPrefill : public GemmConfigBase +struct GemmConfigQuantPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase }; template -struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill +struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { static constexpr bool PreshuffleQuant = true; }; template -struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill { static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 06fc4d5bfa..ad4d0baab2 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { // Need to multiply aquant with accumulated C // - // The accumulated C tile has the standard distribution. For example - // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // The accumulated C tile has the standard distribution. For example, a + // 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], // [26,0], [27,0]. // @@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // // These scales can be obtained using __builtin_amdgcn_ds_bpermute. - // MIters per warp - constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; - // Reg block offset based on mIter - constexpr index_t reg_block_offset = - ((mIter / mIters_per_warp) * Traits::AQPerBlock); - - constexpr index_t lane_base_offset = - (mIter % mIters_per_warp) * WarpGemm::kM; - - // Scale tensor offset along K - constexpr index_t src_reg_offset = reg_block_offset + kQScale; - // Directly index into thread buffer corresponding to - // desired row coefficient + // Each thread stores AQPerBlock scale values per M iteration. + constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock; + constexpr index_t src_reg_offset = reg_block_offset + kQScale; auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; - ; - constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; - constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; - // Multiply by 4 because output is stored in tiles of 4 - // x CNLane - constexpr uint32_t row_base = - ((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) + - ((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane); + // Divide M dimension of C Warp tile into groups of + // (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) + // m_base_offset_of_c_row indicates which group the current c_row belongs + // to. + constexpr index_t m_base_offset_of_c_row = + (c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) * + (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset of each thread within its group (see comment above) + index_t m_base_offset_of_lane = + (get_lane_id() / WarpGemm::kN * + WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset wrt. c_row in the subgroup of kCM1PerLane + constexpr index_t m_offset_of_c_row = + c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1); - // Lane index to source scale from uint32_t src_lane_idx = - lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows); + m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row; return exchange_quant_value_across_lanes(scale_reg, src_lane_idx); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index dae099af4f..f2142b4fdf 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding // # of elements per thread constexpr index_t X = XPerTile; - constexpr index_t Y0 = 1; - constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1; - constexpr index_t Y2 = MWarps; - constexpr index_t Y3 = WarpGemm::kM; - static_assert(Y3 >= WarpGemm::kM, + constexpr index_t YR = 1; + constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1; + constexpr index_t Y1 = MWarps; + constexpr index_t Y2 = WarpGemm::kM; + static_assert(Y2 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); - static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, - "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 1>>, - tuple, sequence<0, 3>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, sequence<1, 2>, - sequence<1, 0>>{}); + sequence<0, 0>>{}); } } }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 5e610cb76b..685f52cdac 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -53,6 +53,13 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefill : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index a75d871421..07aed62804 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types< std::tuple, std::tuple, + // PreshuffleQuant = false && TransposeC = false && Prefill + std::tuple, + std::tuple, + std::tuple, + std::tuple, + // PreshuffleQuant = false && TransposeC = true std::tuple, std::tuple,