optimize cvt_pkf4_to_f16 implementation

This commit is contained in:
Feng Shijie
2025-08-20 04:39:14 +00:00
parent 3ca0bd500a
commit c27eb0771a

View File

@@ -709,6 +709,37 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// Shift the raw scale bits left by float_mantissa. The result is never used as an
// integer below; it is reinterpreted as a float with bit_cast<float>(uscale).
// NOTE(review): presumably this drops an exponent-only scale into the exponent field
// of an IEEE float -- confirm against the definitions of scale.data / float_mantissa.
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
// Two-element packed vector type used when storing dequantized values:
// fp16x2_t when ComputeType is half_t, bf16x2_t for any other ComputeType.
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
#if defined(__gfx950__)
// Dequantization helper for the gfx950 build: forwards a packed FP4 word, a float
// scale and a selector to the scaled-convert builtin that matches ComputeType
// (bf16_t -> ..._pk_bf16_fp4, half_t -> ..._pk_f16_fp4) and returns its result.
// The selector is converted with int(...) directly in the call expression.
// NOTE(review): the builtin presumably needs the selector as a compile-time
// immediate, so it should stay a compile-time constant object -- confirm.
auto pk_mxfp4x4_to_compute_v2 = [](auto packed_fp4_word, float scale_f32, auto byte_sel) {
    if constexpr(std::is_same_v<ComputeType, bf16_t>)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
            packed_fp4_word, scale_f32, int(byte_sel));
    }
    else if constexpr(std::is_same_v<ComputeType, half_t>)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
            packed_fp4_word, scale_f32, int(byte_sel));
    }
    else
    {
        // sizeof of a deduced-type parameter keeps this assertion dependent, so it
        // only fires when this branch is actually instantiated.
        static_assert(sizeof(packed_fp4_word) == 0, "unsupported compute type");
    }
};
// Compile-time loop over i in [0, PackedCnt): convert one packed 32-bit word with
// selector i and store the resulting ComputeV2Type pair into slot i of the
// dequantized B buffer for the current xdl_nIter.
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4x4_to_compute_v2(
// View the quantized weights as XDL_PerWeightK 32-bit words and pick word quant_idx_k.
// NOTE(review): the cast is applied to the tensor object itself, whereas the
// non-gfx950 path reads through quant_weight_tensor.get_thread_buffer();
// presumably this relies on the buffer being the tensor's leading storage -- confirm.
reinterpret_cast<const thread_buffer<uint32_t, XDL_PerWeightK>&>(
quant_weight_tensor)
.get(quant_idx_k),
// uscale holds the scale as raw bits; reinterpret them as a float (no numeric conversion).
bit_cast<float>(uscale),
// The loop index is forwarded as the helper's byte_idx selector.
i));
});
#else
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
@@ -723,9 +754,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
}
};
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
@@ -733,6 +761,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
bit_cast<float>(uscale)));
});
#endif
};
// MAIN LOOP