Revert "Revert " Fp8 block scale quantization for fmha fwd (#3330)" (#3633)" (#3635)

This reverts commit de5a1d730d.

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
ltqin
2026-01-24 01:03:22 +08:00
committed by GitHub
parent 2e08a7e5ab
commit 67f0b74ec6
14 changed files with 667 additions and 84 deletions

View File

@@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str:
# Quant-scale option name -> ck_tile enum spelling emitted into generated
# C++ kernel code. Keyword form of dict() keeps the same insertion order.
QSCALE_MAP = dict(
    no="ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
    pertensor="ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
    blockscale="ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
)
# Quant-scale option name -> runtime check-expression spelling
# (quant_scale_enum values). Keyword form of dict() keeps the same
# insertion order as the original literal.
QSCALE_CHECK_MAP = dict(
    no="quant_scale_enum::no_scale",
    pertensor="quant_scale_enum::pertensor",
    blockscale="quant_scale_enum::blockscale",
)
BIAS_MAP = {

View File

@@ -1024,7 +1024,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
# no need lse/dropout kernels
for logits, qscale, mask, bias, sink in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
@@ -1152,7 +1152,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
["f"],
["no", "pertensor", "blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip

View File

@@ -230,6 +230,8 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
@@ -257,6 +259,9 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -264,6 +269,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
@@ -276,6 +284,9 @@ struct fmha_fwd_args
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
ck_tile::index_t block_scale_size_q;
ck_tile::index_t block_scale_size_kv;
};
struct fmha_fwd_pagedkv_args
@@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
@@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
@@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
@@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.batch_stride_q_descale,
args.batch_stride_k_descale,
args.batch_stride_v_descale,
args.window_size_left,
args.window_size_right,
args.sink_size,
@@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.block_scale_size_q,
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);

View File

@@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
// compute block size
constexpr ck_tile::index_t block_scale_size_q_ = 128;
constexpr ck_tile::index_t block_scale_size_kv_ = 128;
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
return "fp32";
@@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
size_t i_block_scale_q = 0;
size_t i_block_scale_k = 0;
std::vector<int32_t> block_scale_seqstart_q_host = {0};
std::vector<int32_t> block_scale_seqstart_k_host = {0};
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
@@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
max_seqlen_k = real_seqlen_k;
}
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
block_scale_seqstart_q_host.push_back(i_block_scale_q);
block_scale_seqstart_k_host.push_back(i_block_scale_k);
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
@@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
? seqstart_k_with_padding_host.back()
: seqstart_k_host.back()));
const ck_tile::index_t num_block_scale_q =
(mode == mode_enum::batch)
? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_)
: i_block_scale_q;
const ck_tile::index_t num_block_scale_kv =
(mode == mode_enum::batch)
? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_)
: i_block_scale_k;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
@@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// TODO - change the tensor length for different quant scale
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
ck_tile::HostTensor<float> q_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<float> k_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<float> v_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
@@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
k_descale_host(0) = qkv_max / k_dtype_max;
v_descale_host(0) = qkv_max / v_dtype_max;
}
else if(qscale.type == quant_scale_enum::blockscale)
{
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
@@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
@@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
q_descale_buf.ToDevice(q_descale_host.data());
k_descale_buf.ToDevice(k_descale_host.data());
v_descale_buf.ToDevice(v_descale_host.data());
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
// Keep logical starts in seqstart_k; pass padded K via separate pointer
seqstart_k.ToDevice(seqstart_k_host.data());
@@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q;
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv;
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
@@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead;
const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k;
const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k;
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
@@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode,
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
if(qscale.type == quant_scale_enum::blockscale)
{
args.q_descale_ptr =
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
args.k_descale_ptr =
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
args.v_descale_ptr =
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
args.block_scale_seqstart_q_ptr =
(mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer()
: nullptr);
args.block_scale_seqstart_k_ptr =
(mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer()
: nullptr);
args.nhead_stride_q_descale = nhead_stride_q_descale;
args.nhead_stride_k_descale = nhead_stride_k_descale;
args.nhead_stride_v_descale = nhead_stride_v_descale;
args.batch_stride_q_descale = batch_stride_q_descale;
args.batch_stride_k_descale = batch_stride_k_descale;
args.batch_stride_v_descale = batch_stride_v_descale;
args.block_scale_size_q = block_scale_size_q_;
args.block_scale_size_kv = block_scale_size_kv_;
}
else
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
}
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
@@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode,
#endif
// reference
ck_tile::
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t q_offset =
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
const ck_tile::index_t k_offset =
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
ck_tile::reference_batched_quant_gemm<QDataType,
KDataType,
SaccDataType,
SMPLComputeDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale_s_host));
ck_tile::idx_identity{},
ck_tile::idx_identity{},
[&](auto idx, auto value) {
return value * scale_s *
q_descale_host(b_idx,
std::get<0>(idx),
q_offset + std::get<1>(idx) / block_scale_size_q_) *
k_descale_host(b_idx,
std::get<0>(idx) / nr,
k_offset + std::get<2>(idx) / block_scale_size_kv_);
});
}
else
{
ck_tile::
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale_s_host));
}
if(0.f < logits_soft_cap)
{
@@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
}
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref,
v_host_ref,
o_host_ref,
ck_tile::identity{},
ck_tile::identity{},
oacc_element_func);
if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t v_offset =
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
ck_tile::
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref,
v_host_ref,
o_host_ref,
ck_tile::idx_identity{},
[&](auto idx, auto value) {
return ck_tile::type_convert<float>(value) *
v_descale_host(b_idx,
std::get<0>(idx) / nr,
v_offset +
std::get<2>(idx) / block_scale_size_kv_);
},
ck_tile::idx_identity{});
}
else
{
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref,
v_host_ref,
o_host_ref,
ck_tile::identity{},
ck_tile::identity{},
oacc_element_func);
}
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
@@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(o_host_result,
o_host_ref,
@@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(json)
{
dump_fmha_fwd_json_results(*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale.type == quant_scale_enum::no_scale ? "no_scale"
: "pertensor",
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
dump_fmha_fwd_json_results(
*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale.type == quant_scale_enum::no_scale
? "no_scale"
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass ? fwd_result::success : fwd_result::failure;

View File

@@ -13,6 +13,7 @@ enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
};
struct quant_scale_info
@@ -25,6 +26,8 @@ struct quant_scale_info
os << "n";
else if(type == quant_scale_enum::pertensor)
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
}
static quant_scale_info decode(std::string str)
@@ -38,6 +41,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::pertensor;
}
else if(str == "bs" || str == "2")
{
info.type = quant_scale_enum::blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);

View File

@@ -95,10 +95,11 @@ run_fp8bf16_tests() {
for perm in 0 1 ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
for scale in 1 2; do
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
done ; done ; done
done ; done ; done ; done
}
run_fp8fp32_tests() {