mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
UA Sink: add the sink parameter to the interfaces + update examples
This commit is contained in:
@@ -653,6 +653,26 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
args.window_size_right = problem.mask.right;
|
||||
args.is_top_left = (problem.mask.type == mask_enum::mask_top_left);
|
||||
|
||||
// Mirror the per-Q-head sink vector to device memory and hand the
|
||||
// pointer to kargs. Empty `problem.sinks` leaves `args.sink_ptr` as
|
||||
// its default `nullptr` — matches the classic no-sink convention.
|
||||
// The buffer must outlive `ck_tile::unified_attention(args, ...)`
|
||||
// below, so declare it in `run_impl`'s scope alongside seq_lens_buf
|
||||
// / block_tables_buf. We use the `DeviceMem` default-ctor + Realloc
|
||||
// idiom (same as the grouped-gemm examples) because `DeviceMem`
|
||||
// owns a HIP allocation but defines no copy/move, so `sink_buf =
|
||||
// DeviceMem(size)` would shallow-copy and double-free on scope
|
||||
// exit. The device-side kernel does not dereference `sink_ptr` yet;
|
||||
// this commit only verifies the pointer survives the round-trip
|
||||
// through kargs without observable behaviour change.
|
||||
ck_tile::DeviceMem sink_buf;
|
||||
if(!problem.sinks.empty())
|
||||
{
|
||||
sink_buf.Realloc(problem.sinks.size() * sizeof(float));
|
||||
sink_buf.ToDevice(problem.sinks.data());
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
args.num_blks = problem.num_blks;
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -123,6 +123,18 @@ struct unified_attention_args
|
||||
index_t split_stride_lse_acc = 0;
|
||||
index_t nhead_stride_o_acc = 0;
|
||||
index_t nhead_stride_lse_acc = 0;
|
||||
|
||||
// Learnable per-query-head attention sink (GPT-OSS / vLLM convention).
|
||||
// When non-null, the device pointer addresses a contiguous
|
||||
// `float[num_head_q]` of sink scalars; the softmax denominator gains
|
||||
// one virtual key per Q head with logit `sink_ptr[h]` and an all-zero
|
||||
// V row. Default `nullptr` reproduces the classic no-sink softmax.
|
||||
// The host-facing CLI / aiter wrapper promotes bf16 sink tensors to
|
||||
// fp32 host-side before copying them to device — matches the
|
||||
// `const void*` / `const float*` convention used by FMHA.
|
||||
// Only the kargs plumbing reads the pointer today; the device-side
|
||||
// pipeline does not yet consume it.
|
||||
const void* sink_ptr = nullptr;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
|
||||
@@ -407,7 +407,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
args.cache_ptr_int32_overflow_possible,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.is_top_left);
|
||||
args.is_top_left,
|
||||
args.sink_ptr);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(UseDecodeGrid)
|
||||
|
||||
@@ -122,6 +122,14 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t window_size_left = -1;
|
||||
ck_tile::index_t window_size_right = -1;
|
||||
bool is_top_left = false;
|
||||
|
||||
// Learnable per-query-head attention sink. When non-null, points
|
||||
// at `float[num_head_q]` on device; the softmax denominator gains
|
||||
// one virtual key per Q head with logit `sink_ptr[h]` and an
|
||||
// all-zero V row. Default `nullptr` reproduces the classic
|
||||
// no-sink softmax. The pipeline does not yet read this pointer —
|
||||
// it is currently pure payload threaded through kargs.
|
||||
const void* sink_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
|
||||
@@ -196,7 +204,8 @@ struct UnifiedAttentionKernel
|
||||
bool cache_ptr_int32_overflow_possible = false,
|
||||
ck_tile::index_t window_size_left = -1,
|
||||
ck_tile::index_t window_size_right = -1,
|
||||
bool is_top_left = false)
|
||||
bool is_top_left = false,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
// Fuse the Q/K FP8 descales into `scale_s` so the softmax sees a
|
||||
// single combined scalar — matches the Triton FP8 reference
|
||||
@@ -237,7 +246,8 @@ struct UnifiedAttentionKernel
|
||||
output_stride_1,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
is_top_left},
|
||||
is_top_left,
|
||||
sink_ptr},
|
||||
block_tables_ptr,
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
|
||||
Reference in New Issue
Block a user