mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Enable NWarps replication for bquant tile dstr
This commit is contained in:
@@ -342,25 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
});
|
||||
|
||||
// Multiply bquant with accumulated C
|
||||
const index_t reg_offset = [&]() {
|
||||
constexpr bool scale_per_niter_per_warp = Traits::QuantGroupSize::kN == 1 || Traits::NQPerBlock >= Traits::NIterPerWarp * Traits::NWarp;
|
||||
if constexpr(scale_per_niter_per_warp)
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(Traits::NQPerBlock >= Traits::NIterPerWarp)
|
||||
{
|
||||
// Each nIter and warp/thread has its own scale - tile dstr handles the proper loading
|
||||
return nIter * Traits::BQPerBlock + kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Many warps/iters can share the same scale, index from full [NQPerBlock, BQPerBlock] matrix
|
||||
const index_t n_idx_of_warp =
|
||||
nIter * WarpGemm::kN * NWarp + get_warp_id() * WarpGemm::kN;
|
||||
const index_t row_index =
|
||||
n_idx_of_warp / Traits::QuantGroupSize::kN;
|
||||
if(get_lane_id() == 0)
|
||||
{
|
||||
printf("row_index: %d\n", row_index);
|
||||
}
|
||||
return row_index * Traits::BQPerBlock + kQScale;
|
||||
// Many N warps/iters share the same scale, index from full [NQPerBlock=1, BQPerBlock] matrix
|
||||
static_assert(Traits::NQPerBlock == 1);
|
||||
return kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -223,20 +223,20 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
// else if constexpr(YPerTile >= NIterPerWarp)
|
||||
// {
|
||||
// // now all NWarps have the same scale -> replicate
|
||||
// constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
|
||||
// constexpr index_t XR = get_warp_size() / NQPerIter;
|
||||
// static_assert(YPerTile == NQPerIter * NWarps * NIterPerWarp);
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<sequence<MWarps, NWarps, XR>,
|
||||
// tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{});
|
||||
// }
|
||||
else if constexpr(YPerTile >= NIterPerWarp)
|
||||
{
|
||||
// now all NWarps have the same scale -> replicate
|
||||
constexpr index_t NQPerIter = integer_divide_ceil(YPerTile, NIterPerWarp);
|
||||
constexpr index_t XR = get_warp_size() / NQPerIter;
|
||||
static_assert(YPerTile == NQPerIter * NIterPerWarp);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, XR>,
|
||||
tuple<sequence<NIterPerWarp, NQPerIter>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// larger NQ block size, multiple iters/warps use same scales -> replicate to all threads
|
||||
|
||||
Reference in New Issue
Block a user