mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fix quant scale matrix layout for block scale gemm (#3079)
* Adding support for TiledPermuteN * Adding test * moving shuffle functions to common place * resolving commit hook * fix formatting
This commit is contained in:
@@ -307,6 +307,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// hot loop:
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -365,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user