[CK_TILE][FMHA] Enable gpt-oss sink (#3490)

* Enable gptoss sink

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add gptoss sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update CHANGELOG.md

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix test args error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update test_fmha_fwd.cpp

* update sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Revert "update sink test"

This reverts commit 970b4f1686.

* update sink test

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update valid sink_v in splitkv pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update example_fmha_fwd.cpp

* fix lse error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix clangformat error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix aiter scale error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_pipeline_qr_ks_vs.hpp

* div scale_s for sink_value

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update fmha_fwd_runner.hpp

* update sink_value with bias

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Fix typo in dropout parameter in fmha_batch_prefill_kernel

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update example_fmha_fwd.cpp

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* optimized some code

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* fix splitkv error

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update sink reference

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update fmha_fwd_runner.hpp

* Update smoke_test_fwd_sink.sh

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Linjun-AMD
2026-01-14 21:32:06 +08:00
committed by GitHub
parent 693ff3bbb3
commit 717ed0b59f
17 changed files with 487 additions and 110 deletions

View File

@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
"Comma-separated list of length 'b'. If empty, no override.")
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
init_method,
seed,
do_validation,
init_sink_value,
stream_config,
json);
}

View File

@@ -230,6 +230,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
const void* sink_ptr;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
// 1) +
// kargs.kv_last_page_lens[b]
const void* seqstart_q_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -638,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -688,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
}();
@@ -848,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
args.min_seqlen_q,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -893,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();
@@ -960,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1008,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();
@@ -1187,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1239,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
}();

View File

@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
return num_splits;
}
template <typename SMPLComputeDataType>
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
ck_tile::index_t nhead,
ck_tile::index_t real_seqlen_q,
ck_tile::index_t real_seqlen_k)
{
for(auto i_h = 0; i_h < nhead; i_h++)
{
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
}
// Append sink token at the end of each row
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
}
}
}
template <typename DataTypeConfig>
fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::index_t batch,
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::string init_method,
uint32_t seed,
int do_validation,
int init_sink_value,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
bias_host);
}
else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
if(init_sink_value != 0)
{
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
// for close to rowmax values.
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
sink_host);
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
sink_buf.ToDevice(sink_host.data());
knew_buf.ToDevice(knew_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
if(init_sink_value != 0)
args.sink_ptr = sink_buf.GetDeviceBuffer();
else
args.sink_ptr = nullptr;
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // unused in group mode
args.hdim_q = hdim_q;
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
mask.type == mask_enum::mask_top_left));
}
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
if(lse)
if(init_sink_value != 0)
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
// Create extended tensor with sink token
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
{nhead, real_seqlen_q, real_seqlen_k + 1});
// Copy original attention scores and append sink values
copy_attention_scores_with_sink(
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
// Compute softmax on extended tensor
ck_tile::HostTensor<PDataType> p_extended(
{nhead, real_seqlen_q, real_seqlen_k + 1});
if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func);
}
// Extract only the original columns (exclude sink token column)
p_host_ref.ForEach(
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
}
else
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
// No sink tokens - compute softmax directly
if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func);
}
}
if(p_drop > 0)
{
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(

View File

@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1