fix formatting

This commit is contained in:
Sami Remes
2025-10-27 14:28:06 +00:00
parent 98deefac3e
commit 2d86cd0081
2 changed files with 20 additions and 15 deletions

View File

@@ -345,12 +345,14 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
constexpr index_t reg_offset = [&]() {
if constexpr(Traits::NQPerBlock >= Traits::NIterPerWarp)
{
// Each nIter and warp/thread has its own scale - tile dstr handles the proper loading
// Each nIter and warp/thread has its own scale - tile dstr handles
// the proper loading
return nIter * Traits::BQPerBlock + kQScale;
}
else
{
// Many N warps/iters share the same scale, index from full [NQPerBlock=1, BQPerBlock] matrix
// Many N warps/iters share the same scale, index from full
// [NQPerBlock=1, BQPerBlock] matrix
static_assert(Traits::NQPerBlock == 1);
return kQScale;
}

View File

@@ -216,12 +216,13 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
constexpr index_t XR = get_warp_size() / NQPerIter;
static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR>,
tuple<sequence<NIterPerWarp, NWarps, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
tile_distribution_encoding<
sequence<MWarps, XR>,
tuple<sequence<NIterPerWarp, NWarps, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else if constexpr(YPerTile >= NIterPerWarp)
{
@@ -230,16 +231,18 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
constexpr index_t XR = get_warp_size() / NQPerIter;
static_assert(YPerTile == NQPerIter * NIterPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, XR>,
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
sequence<1, 2>,
sequence<0, 0>>{});
tile_distribution_encoding<
sequence<MWarps, NWarps, XR>,
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// larger NQ block size, multiple iters/warps use same scales -> replicate to all threads
// larger NQ block size, multiple iters/warps use same scales -> replicate to all
// threads
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,