Enable NWarps replication for bquant tile dstr

This commit is contained in:
Sami Remes
2025-10-27 14:09:07 +00:00
parent 37738e4cb8
commit 98deefac3e
3 changed files with 41 additions and 32 deletions

View File

@@ -342,25 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
});
// Multiply bquant with accumulated C
const index_t reg_offset = [&]() {
constexpr bool scale_per_niter_per_warp = Traits::QuantGroupSize::kN == 1 || Traits::NQPerBlock >= Traits::NIterPerWarp * Traits::NWarp;
if constexpr(scale_per_niter_per_warp)
constexpr index_t reg_offset = [&]() {
if constexpr(Traits::NQPerBlock >= Traits::NIterPerWarp)
{
// Each nIter and warp/thread has its own scale - tile dstr handles the proper loading
return nIter * Traits::BQPerBlock + kQScale;
}
else
{
// Many warps/iters can share the same scale, index from full [NQPerBlock, BQPerBlock] matrix
const index_t n_idx_of_warp =
nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN;
const index_t row_index =
n_idx_of_warp / Traits::QuantGroupSize::kN;
if(get_lane_id() == 0)
{
printf("row_index: %d\n", row_index);
}
return row_index * Traits::BQPerBlock + kQScale;
// Many N warps/iters share the same scale, index from full [NQPerBlock=1, BQPerBlock] matrix
static_assert(Traits::NQPerBlock == 1);
return kQScale;
}
}();

View File

@@ -223,20 +223,20 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
sequence<1, 2>,
sequence<0, 0>>{});
}
// else if constexpr(YPerTile >= NIterPerWarp)
// {
// // now all NWarps have the same scale -> replicate
// constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
// constexpr index_t XR = get_warp_size() / NQPerIter;
// static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp);
// return make_static_tile_distribution(
// tile_distribution_encoding<sequence<MWarps, NWarps, XR>,
// tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
// tuple<sequence<0, 0>, sequence<0, 1>>,
// tuple<sequence<0, 1>, sequence<2, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{});
// }
else if constexpr(YPerTile >= NIterPerWarp)
{
// now all NWarps have the same scale -> replicate
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
constexpr index_t XR = get_warp_size() / NQPerIter;
static_assert(YPerTile == NQPerIter * NIterPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, XR>,
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// larger NQ block size, multiple iters/warps use same scales -> replicate to all threads