UA Sink: add the sink parameter to the interfaces + update examples

This commit is contained in:
Damien Lejeune
2026-05-28 13:46:50 +00:00
parent 7ad5849d0e
commit e7cb485d98
4 changed files with 46 additions and 3 deletions

View File

@@ -653,6 +653,26 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.window_size_right = problem.mask.right;
args.is_top_left = (problem.mask.type == mask_enum::mask_top_left);
// Mirror the per-Q-head sink vector to device memory and hand the
// pointer to kargs via `args.sink_ptr`. An empty `problem.sinks` leaves
// `args.sink_ptr` at its default `nullptr`, which matches the classic
// no-sink convention. The buffer must outlive
// `ck_tile::unified_attention(args, ...)` below, so declare it in
// `run_impl`'s scope alongside seq_lens_buf / block_tables_buf. We use
// the `DeviceMem` default-ctor + Realloc idiom (same as the
// grouped-gemm examples) because `DeviceMem` owns a HIP allocation but
// declares no copy/move operations of its own, so `sink_buf =
// DeviceMem(size)` would fall back to the implicit member-wise
// (shallow) copy, leaving two owners of one allocation and a
// double-free on scope exit. The device-side kernel does not
// dereference `sink_ptr` yet; this commit only verifies the pointer
// survives the round-trip through kargs without observable behaviour
// change.
// Device buffer backing `args.sink_ptr`. Default-constructed here and only
// sized by the `Realloc` call below when `problem.sinks` is non-empty.
ck_tile::DeviceMem sink_buf;
if(!problem.sinks.empty())
{
// Byte count is one `float` per element of `problem.sinks`.
// NOTE(review): `sizeof(float)` presumes `problem.sinks` holds `float`
// elements — confirm against the declaration of `Problem`.
// NOTE(review): nothing in this block checks that `problem.sinks.size()`
// equals the Q-head count that `sink_ptr` is documented to address
// (`float[num_head_q]`) — confirm this is validated by the caller.
sink_buf.Realloc(problem.sinks.size() * sizeof(float));
// Copy the host-side values into the allocation made above.
sink_buf.ToDevice(problem.sinks.data());
// Publish the device address through the args struct.
args.sink_ptr = sink_buf.GetDeviceBuffer();
}
args.num_blks = problem.num_blks;
args.q_ptr = q_buf.GetDeviceBuffer();

View File

@@ -123,6 +123,18 @@ struct unified_attention_args
index_t split_stride_lse_acc = 0;
index_t nhead_stride_o_acc = 0;
index_t nhead_stride_lse_acc = 0;
// Learnable per-query-head attention sink (GPT-OSS / vLLM convention).
// When non-null, the device pointer addresses a contiguous
// `float[num_head_q]` of sink scalars; the softmax denominator gains
// one virtual key per Q head with logit `sink_ptr[h]` and an all-zero
// V row. Default `nullptr` reproduces the classic no-sink softmax.
// The host-facing CLI / aiter wrapper promotes bf16 sink tensors to
// fp32 host-side before pinning them on device — matches the
// `const void*` / `const float*` convention used by FMHA.
// Only the kargs plumbing reads the pointer today; the device-side
// pipeline does not yet consume it.
const void* sink_ptr = nullptr;
};
std::ostream& operator<<(std::ostream& stream,

View File

@@ -407,7 +407,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
args.cache_ptr_int32_overflow_possible,
args.window_size_left,
args.window_size_right,
args.is_top_left);
args.is_top_left,
args.sink_ptr);
dim3 grids;
if constexpr(UseDecodeGrid)

View File

@@ -122,6 +122,14 @@ struct UnifiedAttentionKernel
ck_tile::index_t window_size_left = -1;
ck_tile::index_t window_size_right = -1;
bool is_top_left = false;
// Learnable per-query-head attention sink. When non-null, points
// at `float[num_head_q]` on device; the softmax denominator gains
// one virtual key per Q head with logit `sink_ptr[h]` and an
// all-zero V row. Default `nullptr` reproduces the classic
// no-sink softmax. The pipeline does not yet read this pointer —
// it is currently pure payload threaded through kargs.
const void* sink_ptr = nullptr;
};
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
@@ -196,7 +204,8 @@ struct UnifiedAttentionKernel
bool cache_ptr_int32_overflow_possible = false,
ck_tile::index_t window_size_left = -1,
ck_tile::index_t window_size_right = -1,
bool is_top_left = false)
bool is_top_left = false,
const void* sink_ptr = nullptr)
{
// Fuse the Q/K FP8 descales into `scale_s` so the softmax sees a
// single combined scalar — matches the Triton FP8 reference
@@ -237,7 +246,8 @@ struct UnifiedAttentionKernel
output_stride_1,
window_size_left,
window_size_right,
is_top_left},
is_top_left,
sink_ptr},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,