Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quantization; remove static quantization

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
Author: rocking
Date: 2025-11-24 16:28:25 +08:00
Committed by: GitHub
parent 096f0a3b23
commit 5948dbffe4
17 changed files with 369 additions and 280 deletions

View File

@@ -178,7 +178,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
uint64_t drop_offset,
bool drop_prefs,
std::string mask_str,
bool squant,
std::string qscale_str,
bool is_rotary_interleaved,
ck_tile::index_t num_splits,
std::string init_method,
@@ -380,6 +380,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
mask_info mask =
mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
quant_scale_info qscale = quant_scale_info::decode(qscale_str);
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
@@ -572,6 +574,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// TODO - change the tensor length for different quant scale
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
@@ -592,7 +599,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
? std::array<ck_tile::index_t, 1>{batch}
: std::array<ck_tile::index_t, 1>{1});
float max_o = 5.0;
if(init_method == "ui" || init_method == "0")
{
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
@@ -640,6 +646,23 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
else if(init_method == "3")
{
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float bias_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<BiasDataType>::max());
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(
knew_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(
vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{
-bias_dtype_max, bias_dtype_max, next_seed()}(bias_host);
}
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
@@ -658,6 +681,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
}
}
if(qscale.type == quant_scale_enum::pertensor)
{
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float qkv_max = 3.f;
q_descale_host(0) = qkv_max / q_dtype_max;
k_descale_host(0) = qkv_max / k_dtype_max;
v_descale_host(0) = qkv_max / v_dtype_max;
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
@@ -667,6 +702,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
@@ -702,81 +740,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v is the same type
// Q tensor
{
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::min());
q_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
float scale = q_dtype_max / max_value;
q_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
self(idx) = ck_tile::type_convert<QDataType>(val * scale);
});
scale_s = scale_s / scale;
}
// K tensor
{
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::min());
k_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
float scale = k_dtype_max / max_value;
k_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
self(idx) = ck_tile::type_convert<KDataType>(val * scale);
});
scale_s = scale_s / scale;
}
// V tensor
{
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::min());
v_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
float scale = k_dtype_max / max_value;
v_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
self(idx) = ck_tile::type_convert<VDataType>(val * scale);
});
scale_o = (1.0 / p_dtype_max) / scale;
}
scale_p = p_dtype_max;
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
{
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
scale_o = scale_o * o_dtype_max / max_o;
}
}
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
knew_buf.ToDevice(knew_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
q_descale_buf.ToDevice(q_descale_host.data());
k_descale_buf.ToDevice(k_descale_host.data());
v_descale_buf.ToDevice(v_descale_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
// Keep logical starts in seqstart_k; pass padded K via separate pointer
seqstart_k.ToDevice(seqstart_k_host.data());
@@ -816,7 +788,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
<< (seqlen_kpads[0] < 0 ? ""
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", p_drop:" << p_drop << ", lse:" << lse << ", qscale:" << qscale
<< ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c");
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < rotary_dim)
@@ -908,11 +880,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
traits.qscale_type = qscale.type;
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
std::decay_t<decltype(traits)>>)
@@ -1055,8 +1027,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.scale_p = scale_p;
args.scale_o = scale_o;
args.logits_soft_cap = logits_soft_cap;
@@ -1076,6 +1046,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;
@@ -1351,23 +1325,34 @@ fwd_result fmha_fwd_run(mode_enum mode,
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());
constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
constexpr bool supports_qscale = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;
float scale_s_host = scale_s;
float scale_p_host = 1.0f;
float scale_o_host = 1.0f;
if(qscale.type == quant_scale_enum::pertensor)
{
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
scale_o_host = v_descale_host(0) / scale_p_host;
}
auto p_compute_element_func = [&]() {
if constexpr(supports_squant)
return ck_tile::scales{scale_p};
if constexpr(supports_qscale)
return ck_tile::scales{scale_p_host};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_squant)
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else if constexpr(supports_squant)
return ck_tile::scales{scale_o};
ck_tile::scales{scale_o_host});
else if constexpr(supports_qscale)
return ck_tile::scales{scale_o_host};
else
return ck_tile::identity{};
}();
@@ -1573,7 +1558,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale_s));
ck_tile::scales(scale_s_host));
if(0.f < logits_soft_cap)
{
@@ -1818,7 +1803,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
scale_s,
p_drop,
lse,
squant,
qscale.type == quant_scale_enum::no_scale ? "no_scale"
: "pertensor",
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),