Revert "Revert " Fp8 block scale quantization for fmha fwd (#3330)" (#3633)" (#3635)

This reverts commit de5a1d730d.

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
ltqin
2026-01-24 01:03:22 +08:00
committed by GitHub
parent 2e08a7e5ab
commit 67f0b74ec6
14 changed files with 667 additions and 84 deletions

View File

@@ -168,6 +168,29 @@ struct FmhaFwdKernel
const void* v_descale_ptr = nullptr;
};
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
{
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
ck_tile::index_t block_scale_size_q;
ck_tile::index_t block_scale_size_kv;
};
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale;
};
struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
{
const int32_t* block_scale_seqstart_q_ptr;
const int32_t* block_scale_seqstart_k_ptr;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -243,9 +266,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdBatchBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -269,9 +295,12 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdGroupBlockScaleKargs,
FmhaFwdEmptyKargs<3>>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
@@ -328,6 +357,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -335,6 +367,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -343,6 +378,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -413,6 +450,23 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.batch_stride_q_descale = batch_stride_q_descale;
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
kargs.block_scale_size_q = block_scale_size_q;
kargs.block_scale_size_kv = block_scale_size_kv;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -478,6 +532,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -485,6 +542,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -492,6 +552,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -528,6 +590,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -535,6 +600,9 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -542,6 +610,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -581,6 +651,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
@@ -588,6 +661,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_q_descale,
ck_tile::index_t batch_stride_k_descale,
ck_tile::index_t batch_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -595,6 +671,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -631,6 +709,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -638,6 +719,9 @@ struct FmhaFwdKernel
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
batch_stride_q_descale,
batch_stride_k_descale,
batch_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -645,6 +729,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -666,6 +752,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -685,6 +773,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -694,6 +785,8 @@ struct FmhaFwdKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -763,6 +856,24 @@ struct FmhaFwdKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.block_scale_size_q = block_scale_size_q;
kargs.block_scale_size_kv = block_scale_size_kv;
kargs.block_scale_seqstart_q_ptr =
reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
kargs.block_scale_seqstart_k_ptr =
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -814,6 +925,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -833,6 +946,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -841,6 +957,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -860,6 +978,8 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -879,6 +999,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -887,6 +1010,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -909,6 +1034,8 @@ struct FmhaFwdKernel
const void* seqstart_k_ptr,
const void* seqlen_q_ptr,
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -928,6 +1055,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_q_descale,
ck_tile::index_t nhead_stride_k_descale,
ck_tile::index_t nhead_stride_v_descale,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
@@ -936,6 +1066,8 @@ struct FmhaFwdKernel
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
@@ -955,6 +1087,8 @@ struct FmhaFwdKernel
seqstart_k_ptr,
seqlen_q_ptr,
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -974,6 +1108,9 @@ struct FmhaFwdKernel
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
nhead_stride_q_descale,
nhead_stride_k_descale,
nhead_stride_v_descale,
window_size_left,
window_size_right,
sink_size,
@@ -982,6 +1119,8 @@ struct FmhaFwdKernel
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
block_scale_size_q,
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
@@ -1111,13 +1250,16 @@ struct FmhaFwdKernel
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q_descale = 0;
long_index_t batch_offset_k_descale = 0;
long_index_t batch_offset_v_descale = 0;
const float sink_value =
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
@@ -1153,6 +1295,14 @@ struct FmhaFwdKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
batch_offset_q_descale = bquery_start;
batch_offset_k_descale = bkey_start;
batch_offset_v_descale = bkey_start;
}
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
@@ -1220,6 +1370,15 @@ struct FmhaFwdKernel
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
batch_offset_q_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
batch_offset_k_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
batch_offset_v_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// If cumulative seqlen pointers are provided, override per-batch effective lengths
@@ -1540,7 +1699,8 @@ struct FmhaFwdKernel
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
@@ -1581,8 +1741,62 @@ struct FmhaFwdKernel
block_indices,
smem_ptr,
dropout,
nullptr,
nullptr,
1,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
const float* q_descale_ptr =
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
batch_offset_q_descale;
const float* k_descale_ptr =
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k_descale +
batch_offset_k_descale;
const float* v_descale_ptr =
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v_descale +
batch_offset_v_descale;
size_t idx = i_m0 / kargs.block_scale_size_q;
float q_descale = q_descale_ptr[idx];
// BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
// Both P and rowsum are scaled by 2^shift, canceling in normalization
// No additional scaling needed in p_compute_element_func or o_acc_element_func
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
scales<float>(q_descale), // s_acc_element_func
identity{}, // p_compute_element_func - No scaling (done in exp2)
identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
k_descale_ptr,
v_descale_ptr,
kargs.block_scale_size_kv,
sink_value);
}
else
{
return FmhaPipeline{}(q_dram_window,