mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
optimize cvt_pkf4_to_f16 implementation
This commit is contained in:
@@ -709,6 +709,37 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4x4_to_compute_v2(
|
||||
reinterpret_cast<const thread_buffer<uint32_t, XDL_PerWeightK>&>(
|
||||
quant_weight_tensor)
|
||||
.get(quant_idx_k),
|
||||
bit_cast<float>(uscale),
|
||||
i));
|
||||
});
|
||||
#else
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
@@ -723,9 +754,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
@@ -733,6 +761,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
#endif
|
||||
};
|
||||
|
||||
// MAIN LOOP
|
||||
|
||||
Reference in New Issue
Block a user