From 98deefac3e95404cc3aef574cfd7b0c6beec6ea7 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 27 Oct 2025 14:09:07 +0000 Subject: [PATCH] Enable NWarps replication for bquant tile dstr --- .../block_universal_gemm_as_bs_bquant_cr.hpp | 18 ++++-------- .../pipeline/gemm_group_quant_utils.hpp | 28 +++++++++---------- .../test_gemm_quant_typed.cpp | 27 ++++++++++++++---- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index f71937ca50..8017a82bc6 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -342,25 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase }); // Multiply bquant with accumulated C - const index_t reg_offset = [&]() { - constexpr bool scale_per_niter_per_warp = Traits::QuantGroupSize::kN == 1 || Traits::NQPerBlock >= Traits::NIterPerWarp * Traits::NWarp; - if constexpr(scale_per_niter_per_warp) + constexpr index_t reg_offset = [&]() { + if constexpr(Traits::NQPerBlock >= Traits::NIterPerWarp) { // Each nIter and warp/thread has its own scale - tile dstr handles the proper loading return nIter * Traits::BQPerBlock + kQScale; } else { - // Many warps/iters can share the same scale, index from full [NQPerBlock, BQPerBlock] matrix - const index_t n_idx_of_warp = - nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN; - const index_t row_index = - n_idx_of_warp / Traits::QuantGroupSize::kN; - if(get_lane_id() == 0) - { - printf("row_index: %d\n", row_index); - } - return row_index * Traits::BQPerBlock + kQScale; + // Many N warps/iters share the same scale, index from full [NQPerBlock=1, BQPerBlock] matrix + static_assert(Traits::NQPerBlock == 1); + return kQScale; } }(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 33b2e413e1..2dac5cddb1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -223,20 +223,20 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding sequence<1, 2>, sequence<0, 0>>{}); } - // else if constexpr(YPerTile >= NIterPerWarp) - // { - // // now all NWarps have the same scale -> replicate - // constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp); - // constexpr index_t XR = get_warp_size() / NQPerIter; - // static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp); - // return make_static_tile_distribution( - // tile_distribution_encoding, - // tuple, sequence>, - // tuple, sequence<0, 1>>, - // tuple, sequence<2, 1>>, - // sequence<1, 2>, - // sequence<0, 0>>{}); - // } + else if constexpr(YPerTile >= NIterPerWarp) + { + // now all NWarps have the same scale -> replicate + constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp); + constexpr index_t XR = get_warp_size() / NQPerIter; + static_assert(YPerTile == NQPerIter * NIterPerWarp); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1>>, + tuple, sequence<2, 1>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } else { // larger NQ block size, multiple iters/warps use same scales -> replicate to all threads diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index eb6114d08a..eb895227b3 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -22,7 +22,12 @@ using RowColQuant = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; using GroupSize64 = ck_tile::QuantGroupShape>; -using GroupSize2D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for each quantization type // clang-format off @@ -67,10 +72,22 @@ using BQuantTypes = ::testing::Types< std::tuple, // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on