mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)
- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats - Implement A(bf8) B(i4) support in universal GEMM - Use new implementation for i4 to fp8 conversion in Block Scale
This commit is contained in:
@@ -181,9 +181,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
@@ -451,7 +449,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -471,7 +469,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -556,7 +554,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
reg_offset_for_row_data] +=
|
||||
(c_warp_tensor
|
||||
.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +179,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
@@ -384,8 +382,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
|
||||
kA_cvt_scale * kB_cvt_scale);
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user