diff --git a/example/ck_tile/42_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp index d5a7d88abd..88788fa51c 100644 --- a/example/ck_tile/42_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/example_unified_attention.cpp @@ -653,6 +653,26 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.window_size_right = problem.mask.right; args.is_top_left = (problem.mask.type == mask_enum::mask_top_left); + // Mirror the per-Q-head sink vector to device memory and hand the + // pointer to kargs. Empty `problem.sinks` leaves `args.sink_ptr` as + // its default `nullptr` — matches the classic no-sink convention. + // The buffer must outlive `ck_tile::unified_attention(args, ...)` + // below, so declare it in `run_impl`'s scope alongside seq_lens_buf + // / block_tables_buf. We use the `DeviceMem` default-ctor + Realloc + // idiom (same as the grouped-gemm examples) because `DeviceMem` + // owns a HIP allocation but defines no copy/move, so `sink_buf = + // DeviceMem(size)` would shallow-copy and double-free on scope + // exit. The device-side kernel does not dereference `sink_ptr` yet; + // this commit only verifies the pointer survives the round-trip + // through kargs without observable behaviour change. + ck_tile::DeviceMem sink_buf; + if(!problem.sinks.empty()) + { + sink_buf.Realloc(problem.sinks.size() * sizeof(float)); + sink_buf.ToDevice(problem.sinks.data()); + args.sink_ptr = sink_buf.GetDeviceBuffer(); + } + args.num_blks = problem.num_blks; args.q_ptr = q_buf.GetDeviceBuffer(); diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index a6965f0646..57297938ba 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -123,6 +123,18 @@ struct unified_attention_args index_t split_stride_lse_acc = 0; index_t nhead_stride_o_acc = 0; index_t nhead_stride_lse_acc = 0; + + // Learnable per-query-head attention sink (GPT-OSS / vLLM convention). + // When non-null, the device pointer addresses a contiguous + // `float[num_head_q]` of sink scalars; the softmax denominator gains + // one virtual key per Q head with logit `sink_ptr[h]` and an all-zero + // V row. Default `nullptr` reproduces the classic no-sink softmax. + // The host-facing CLI / aiter wrapper promotes bf16 sink tensors to + // fp32 host-side before pinning them on device — matches the + // `const void*` / `const float*` convention used by FMHA. + // Only the kargs plumbing reads the pointer today; the device-side + // pipeline does not yet consume it. + const void* sink_ptr = nullptr; }; std::ostream& operator<<(std::ostream& stream, diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index c5f31dba6e..402eea861d 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -407,7 +407,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args, args.cache_ptr_int32_overflow_possible, args.window_size_left, args.window_size_right, - args.is_top_left); + args.is_top_left, + args.sink_ptr); dim3 grids; if constexpr(UseDecodeGrid) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index aeb444c196..13b4515443 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -122,6 +122,14 @@ struct UnifiedAttentionKernel ck_tile::index_t window_size_left = -1; ck_tile::index_t window_size_right = -1; bool is_top_left = false; + + // Learnable per-query-head attention sink. When non-null, points + // at `float[num_head_q]` on device; the softmax denominator gains + // one virtual key per Q head with logit `sink_ptr[h]` and an + // all-zero V row. Default `nullptr` reproduces the classic + // no-sink softmax. The pipeline does not yet read this pointer — + // it is currently pure payload threaded through kargs. + const void* sink_ptr = nullptr; }; struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs @@ -196,7 +204,8 @@ struct UnifiedAttentionKernel bool cache_ptr_int32_overflow_possible = false, ck_tile::index_t window_size_left = -1, ck_tile::index_t window_size_right = -1, - bool is_top_left = false) + bool is_top_left = false, + const void* sink_ptr = nullptr) { // Fuse the Q/K FP8 descales into `scale_s` so the softmax sees a // single combined scalar — matches the Triton FP8 reference @@ -237,7 +246,8 @@ struct UnifiedAttentionKernel output_stride_1, window_size_left, window_size_right, - is_top_left}, + is_top_left, + sink_ptr}, block_tables_ptr, block_table_stride, seq_lens_ptr,