add per_token_head quantization to fmha_fwd

This commit is contained in:
msaffari-amd
2026-05-27 11:32:37 +00:00
parent ee3ada6e4a
commit b5cd209196
3 changed files with 286 additions and 14 deletions

View File

@@ -1062,6 +1062,20 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
# PER_TOKEN_HEAD: FP8 fine-grained Q/K per-token-per-head + V per-head.
# Currently only wired for fp8bf16 + qr_async + hdim=128 on gfx9 (MI300/MI308).
# The non-paged fmha_fwd qr_async pipeline gained PER_TOKEN_HEAD support in
# block_fmha_pipeline_qr_ks_vs_async.hpp; keep this scope tight until we
# also extend qr / qr_async_trload / V3 pipelines.
if dtype in cls._DT_FP8BF16 and hdim == 128:
for logits, mask, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip
return pipelines

View File

@@ -305,6 +305,41 @@ struct FmhaFwdKernel
const int32_t* seqstart_v_scale_ptr;
};
// PER_TOKEN_HEAD descale layout (non-paged fmha_fwd):
// q_descale: [B, total_q, nhead_q] fp32 (batch+token row stride, head col stride)
// k_descale: [B, total_k, nhead_k] fp32
// v_descale: [nhead_k] fp32 (per-head only; no batch / token dim)
// For varlen (group) mode the batch dim is collapsed into total_q / total_k and the
// per-batch offset is derived from cu_seqlens * stride_*_descale instead of
// batch_stride_*_descale.
struct FmhaFwdCommonPerTokenHeadKargs : FmhaFwdCommonQScaleKargs
{
// Per-token (row) strides; V is per-head only, so stride_v_descale is unused.
ck_tile::index_t stride_q_descale;
ck_tile::index_t stride_k_descale;
ck_tile::index_t nhead_stride_q_descale;
ck_tile::index_t nhead_stride_k_descale;
ck_tile::index_t nhead_stride_v_descale;
// Unused under PER_TOKEN_HEAD but the qr_async pipeline takes it positionally;
// keep it as a no-op field rather than baking a constant into the call site.
ck_tile::index_t block_scale_size_kv = 128;
};
struct FmhaFwdBatchPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs
{
ck_tile::index_t batch_stride_q_descale;
ck_tile::index_t batch_stride_k_descale;
ck_tile::index_t batch_stride_v_descale; // callers should set this to 0 (V is per-head)
};
struct FmhaFwdGroupPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs
{
// group mode resolves per-batch offsets from cu_seqlens * stride_*_descale,
// so no dedicated batch_stride_*_descale is needed here.
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -383,11 +418,16 @@ struct FmhaFwdKernel
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdBatchBlockScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
FmhaFwdBatchMXKargs,
FmhaFwdEmptyKargs<3>>>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdBatchBlockScaleKargs,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::MX,
FmhaFwdBatchMXKargs,
std::conditional_t<QScaleEnum ==
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD,
FmhaFwdBatchPerTokenHeadKargs,
FmhaFwdEmptyKargs<3>>>>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -414,11 +454,16 @@ struct FmhaFwdKernel
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdGroupBlockScaleKargs,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
FmhaFwdGroupMXKargs,
FmhaFwdEmptyKargs<3>>>>,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
FmhaFwdGroupBlockScaleKargs,
std::conditional_t<
QScaleEnum == BlockAttentionQuantScaleEnum::MX,
FmhaFwdGroupMXKargs,
std::conditional_t<QScaleEnum ==
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD,
FmhaFwdGroupPerTokenHeadKargs,
FmhaFwdEmptyKargs<3>>>>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
@@ -612,6 +657,28 @@ struct FmhaFwdKernel
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD (non-paged fmha_fwd) descale layout:
// q_descale: [total_q, nhead_q] fp32 (per-token, per-head)
// k_descale: [total_k, nhead_k] fp32 (per-token, per-head)
// v_descale: [nhead_k] fp32 (per-head only)
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
// per-token row strides (V is per-head scalar so stride_v_descale is unused).
kargs.stride_q_descale = stride_q_descale;
kargs.stride_k_descale = stride_k_descale;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
kargs.batch_stride_q_descale = batch_stride_q_descale;
kargs.batch_stride_k_descale = batch_stride_k_descale;
kargs.batch_stride_v_descale = batch_stride_v_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1065,6 +1132,25 @@ struct FmhaFwdKernel
kargs.seqstart_v_scale_ptr = reinterpret_cast<const int32_t*>(seqstart_v_scale_ptr);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD (non-paged fmha_fwd, group mode) descale layout:
// q_descale: [total_q, nhead_q] fp32 (per-token, per-head)
// k_descale: [total_k, nhead_k] fp32 (per-token, per-head)
// v_descale: [nhead_k] fp32 (per-head only)
// Per-batch offsets are derived from query_start/key_start in the kernel,
// so we don't need batch_stride_*_descale here (group mode).
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.stride_q_descale = stride_q_descale;
kargs.stride_k_descale = stride_k_descale;
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1574,6 +1660,16 @@ struct FmhaFwdKernel
batch_offset_k_descale = key_start * kargs.stride_k_descale;
batch_offset_v_descale = kargs.seqstart_v_scale_ptr[i_batch];
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD descales: Q/K per-token-per-head, V per-head only.
// q_descale: [total_q, nhead_q] -> per-batch offset = query_start * row_stride
// k_descale: [total_k, nhead_k] -> per-batch offset = key_start * row_stride
// v_descale: [nhead_k] -> no per-batch offset (shared across batches)
batch_offset_q_descale = query_start * kargs.stride_q_descale;
batch_offset_k_descale = key_start * kargs.stride_k_descale;
batch_offset_v_descale = 0;
}
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)
@@ -1642,8 +1738,12 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::MX)
QScaleEnum == BlockAttentionQuantScaleEnum::MX ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// For PER_TOKEN_HEAD V is per-head only (no per-batch dimension);
// callers should set kargs.batch_stride_v_descale = 0 in that case
// so this expression evaluates to 0 here.
batch_offset_q_descale =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
batch_offset_k_descale =
@@ -2121,6 +2221,66 @@ struct FmhaFwdKernel
make_null_tile_window(make_tuple()),
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD: Q/K per-(token, head), V per-head (FP8 fine-grained).
// Resolve descale base ptrs to (batch, head); the pipeline indexes
// per-token offsets within head using stride_q/k_descale.
//
// Layout convention (non-paged fmha_fwd):
// q_descale: [total_q, nhead_q] fp32
// k_descale: [total_k, nhead_k] fp32
// v_descale: [nhead_k] fp32
const float* q_descale_ptr =
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
batch_offset_q_descale;
const float* k_descale_ptr =
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k_descale +
batch_offset_k_descale;
const float* v_descale_ptr =
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v_descale +
batch_offset_v_descale;
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func - PER_TOKEN_HEAD applies its own per-(row,col) scale
identity{}, // p_compute_element_func
identity{}, // o_acc_element_func - V descale is folded in via 'o_acc += o_acc0 * v_descale'
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout,
k_descale_ptr,
v_descale_ptr,
kargs.block_scale_size_kv, // unused for PER_TOKEN_HEAD
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
make_null_tile_window(make_tuple()),
sink_value,
// PER_TOKEN_HEAD-only:
q_descale_ptr,
kargs.stride_q_descale,
kargs.stride_k_descale);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
{
using QScaleDataType = typename FmhaPipeline::QScaleDataType;

View File

@@ -204,7 +204,16 @@ struct BlockFmhaPipelineQRKSVSAsync
const QScaleDramBlockWindowTmp&, // M0*(K0/kQKScaleGranularity) tile
const KScaleDramBlockWindowTmp&, // N0*(K0/kQKScaleGranularity) tile
const VScaleDramBlockWindowTmp&, // N1*(K1/kVScaleGranularity) tile
const float sink_v) const
const float sink_v,
// PER_TOKEN_HEAD: per-(token, head) Q/K descales; V is per-head scalar.
// q_descale_ptr / k_descale_ptr / v_descale_ptr are already
// (batch, head)-resolved by the kernel; pipeline indexes
// q_descale_ptr[(q_origin + i) * stride_q_descale_token]
// k_descale_ptr[(k_origin + j) * stride_k_descale_token]
// v_descale = *v_descale_ptr (loaded once per loop iter)
const float* q_descale_ptr = nullptr,
const index_t stride_q_descale_token = 0,
const index_t stride_k_descale_token = 0) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -463,6 +472,87 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// PER_TOKEN_HEAD: dequantize QK with per-row Q descale and per-col K descale.
// s_acc[i,j] *= q_descale[(q_origin + i) * stride_q_descale_token]
// * k_descale[(k_origin + j) * stride_k_descale_token]
// q_descale_ptr / k_descale_ptr are already (batch, head)-resolved by the kernel.
//
// qr_async (no trload) on gfx9 is very tight on SGPRs; folding both row + col
// global loads into a single 2-D sweep over s_acc made the compiler spill the
// K/V buffer-SRDs to VGPRs and produce invalid asm. We instead split the
// dequant into two 1-D sweeps so each pass keeps a single descale pointer
// live and reuses the existing tile-distribution machinery without inflating
// SGPR pressure inside the K/V async-load region.
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD dequant: s_acc[i,j] *= q_descale[i] * k_descale[j].
//
// qr_async on gfx9 is extremely tight on SGPRs (the K/V buffer-SRDs
// must stay live across iterations of the main K-loop). When threads
// issue scalar global loads from k_descale_ptr inside the per-tile
// 2-D sweep, the compiler runs out of SGPR budget and spills the K
// SRD into VGPRs, producing invalid `buffer_load_dword v_dst, v[..]
// ... lds` asm on gfx942.
//
// To break that dependency we stage both Q-row and K-col descales
// through LDS first (one warp-strided load per descale tile), then
// the per-element sweep is a pure LDS read + FP multiply. This keeps
// the K SRD live-range contained inside the QK gemm, where the
// existing schedule already accounts for it.
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
const index_t q_row_base = q_origin.at(number<0>{});
// Mirror the V / randval path's sink-aware col-base derivation so we
// never touch k_dram_block_window.get_window_origin() inside this
// SGPR-tight region (that read leaks the K window state and was
// independently observed to push the SRD into VGPRs).
const bool in_sink_phase = (num_sink_loop > i_total_loops);
const index_t k_col_base =
in_sink_phase
? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
__builtin_amdgcn_sched_barrier(0);
// LDS staging tiles. Allocated per-block; kept small (kM0 + kN0 fp32).
__shared__ float lds_q_descale[kM0];
__shared__ float lds_k_descale[kN0];
const index_t tid_in_block =
static_cast<index_t>(threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * blockDim.x * blockDim.y);
const index_t threads_per_block =
static_cast<index_t>(blockDim.x * blockDim.y * blockDim.z);
// Q-row descales (kM0 entries).
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
{
lds_q_descale[off] =
q_descale_ptr[(q_row_base + off) * stride_q_descale_token];
}
// K-col descales (kN0 entries).
for(index_t off = tid_in_block; off < kN0; off += threads_per_block)
{
lds_k_descale[off] =
k_descale_ptr[(k_col_base + off) * stride_k_descale_token];
}
__builtin_amdgcn_s_barrier();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const index_t i = tile_idx.at(number<0>{});
const index_t j = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j];
});
});
__builtin_amdgcn_sched_barrier(0);
}
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
@@ -785,12 +875,19 @@ struct BlockFmhaPipelineQRKSVSAsync
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
v_descale = v_descale_ptr[kv_idx];
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
// PER_TOKEN_HEAD: V scale is per-head only; v_descale_ptr is already
// (batch, head)-resolved by the kernel, so load the single scalar.
v_descale = *v_descale_ptr;
}
// STAGE 3, KV gemm
auto o_acc0 = decltype(o_acc){};
clear_tile(o_acc0);
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
return o_acc0;
}
@@ -879,7 +976,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
tile_elementwise_inout(
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);