Fix quant scale matrix layout for block scale gemm (#3079)

* Adding support for TiledPermuteN

* Adding test

* moving shuffle functions to common place

* resolving commit hook

* fix formatting
This commit is contained in:
Khushbu Agarwal
2025-10-27 13:56:07 -07:00
committed by GitHub
parent a46b725992
commit b11f53a484
10 changed files with 35 additions and 31 deletions

View File

@@ -307,6 +307,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -365,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});
});