[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)

[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Sink tokens are already supported in the FMHA forward pass (per-head sink scores, stored in log-space, that absorb part of the softmax mass). This change adds the missing backward piece: computing the gradient of those sink scores in the dot_do_o pipeline.

## Technical Details

- Extend `BlockFmhaBwdOGradDotOPipelineProblem` with `LSEDataType`
- Add `sink_ptr`/`d_sink_ptr`/`lse_ptr`/`nhead` to `FmhaBwdOGradDotOCommonKargs`
- Compute the per-head sink gradient via atomic accumulation in the pipeline (see the math sketch after this list)
- Update the example runner with reference validation for the sink gradient
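
For reference, the sink math this change implements, reconstructed from the reference-validation comments in the diff below ($s_h$ is the per-head log-space sink score; $D$ and $p_{\mathrm{undrop}}$ follow the existing dot_do_o definitions):

$$\mathrm{lse}_{\mathrm{new}} = \log\!\big(e^{\mathrm{lse}_{\mathrm{old}}} + e^{s_h}\big),\qquad P_{\mathrm{new}} = P_{\mathrm{old}}\,e^{\mathrm{lse}_{\mathrm{old}}-\mathrm{lse}_{\mathrm{new}}},\qquad P_{\mathrm{sink}} = e^{s_h-\mathrm{lse}_{\mathrm{new}}}$$

$$\frac{\partial L}{\partial s_h} = -\sum_q P_{\mathrm{sink}}[h,q]\,D[h,q],\qquad D[h,q] = p_{\mathrm{undrop}}\sum_j \mathrm{d}O[h,q,j]\,O[h,q,j]$$

The sink contributes no value vector, so its score gradient reduces to $-P_{\mathrm{sink}}\,D$, which is why the dot_do_o pass (which already computes $D$) is the natural place to emit it.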

## Test Plan

Add new smoke-test cases: `-sink_grad=1` runs mirroring the existing precision/layout/mode/bias/dropout matrix in `smoke_test_bwd.sh`, plus non-standard head dimensions (40/48/72/96), and the forward sink mask/init tests folded into `smoke_test_fwd.sh`.

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit 08792e0b31 (parent c1127a36f5) by Linjun-AMD, 2026-04-02 03:17:45 +00:00, committed by assistant-librarian[bot].
12 changed files with 380 additions and 130 deletions

View File

@@ -533,6 +533,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
/* BlockSize = M0 = */ {F_bm0},
{F_hdim},
{F_mode},
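
This hunk threads `LSEDataType` into the dot_do_o pipeline problem so the pipeline can read the LSE and sink scores and emit the sink gradient. The pipeline body itself is not part of this excerpt; conceptually, each workgroup folds its rows' contributions into the per-head output with an atomic add. A minimal sketch, all names illustrative:

```cpp
// Hedged sketch of the per-head atomic accumulation described in the PR text;
// the actual pipeline code is not shown in this diff excerpt.
// p_sink  : sink attention weight P_sink[h,q] for this query row
// d_val   : D[h,q] = p_undrop * sum_j dO*O, already produced by dot_do_o
// i_nhead : head index handled by this workgroup
float partial = -p_sink * d_val;           // one row's dSink contribution
atomicAdd(&d_sink_ptr[i_nhead], partial);  // accumulate across rows and batches
```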

View File

@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
"will not be used")
.insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
bool deterministic = arg_parser.get_bool("deterministic");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
bool sink_grad = arg_parser.get_bool("sink_grad");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
drop_offset,
drop_prefs,
mask_str,
sink_grad,
deterministic,
init_method,
seed,

View File

@@ -116,6 +116,9 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void*
sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient
// Usage notes for sequence length pointer parameters:
//
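
The two new pointers are opt-in: leaving them `nullptr` preserves the old behavior. A hypothetical call-site sketch (buffer names follow the example runner later in this diff; all other fields elided):

```cpp
fmha_bwd_args args{}; // remaining fields elided for brevity
// Log-space sink scores, shape [batch, nhead]; nullptr disables sink handling.
args.sink_ptr   = sink_grad ? sink_buf.GetDeviceBuffer() : nullptr;
// Per-head sink gradient output, shape [nhead]; must be zero-initialized
// because the kernel accumulates into it atomically.
args.d_sink_ptr = sink_grad ? d_sink_buf.GetDeviceBuffer() : nullptr;
```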
@@ -362,11 +365,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.lse_ptr,
args.sink_ptr,
args.d_sink_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.nhead_q,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
@@ -378,9 +385,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.lse_ptr,
args.sink_ptr,
args.d_sink_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.nhead_q,
args.stride_do,
args.stride_o,
args.nhead_stride_do,

View File

@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
uint64_t drop_offset,
bool drop_prefs,
std::string mask_str,
bool sink_grad, // if true, compute and validate sink gradient
bool deterministic,
std::string init_method,
uint32_t seed,
@@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<LSEDataType> sink_host(
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
if(sink_grad)
{
std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
sink_host.ForEach([&](auto& self, auto i) {
self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
});
}
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
@@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
{
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
}
ck_tile::HostTensor<AccDataType> dq_acc_host(
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
@@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
if(sink_grad)
{
sink_buf.ToDevice(sink_host.data());
d_sink_buf.ToDevice(d_sink_host.data());
}
// clang-format off
auto layout_str = [&](bool permute){
@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< (deterministic
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
"MiB|" + std::to_string(nsplits) + "splits"
@@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
sink_buf.GetDeviceBuffer(),
d_sink_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
@@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;
randval_buf.FromDevice(randval_host.data());
@@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
// Incorporate sink token into the softmax distribution (reference computation).
// The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
// which is a per-head random value in [30, 60].
// lse_new = log(exp(lse_old) + exp(sink))
// P_new = P_old * exp(lse_old - lse_new) (rescaled token attention)
// P_sink = exp(sink - lse_new) (sink attention weight)
ck_tile::HostTensor<AccDataType> p_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
: std::array<ck_tile::index_t, 2>{0, 0});
if(sink_grad)
{
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType sink_val = sink_host(wb, i_h);
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
// Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
// = max(lse_old, sink) + log(1 + exp(min - max))
// This handles lse_old = -inf (fully-masked rows) without producing NaN:
// if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
// It also avoids exp(lse_old) overflow when lse_old is large.
// p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens]
// p_sink = exp(sink - lse_new) [sink attention weight]
AccDataType lse_old = lse_host_ref(i_h, i_q);
AccDataType hi = lse_old > sink_val ? lse_old : sink_val;
AccDataType lo = lse_old > sink_val ? sink_val : lse_old;
AccDataType lse_new =
hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
AccDataType p_scale = ck_tile::exp(lse_old - lse_new);
lse_host_ref(i_h, i_q) = lse_new;
for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
p_hp_host_ref(i_h, i_q, i_k) *= p_scale;
p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
}
}
}
if(p_drop > 0)
{
p_dropped_hp_host_ref = p_hp_host_ref;
@@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_host_refs.push_back(o_host_ref);
p_hp_host_refs.push_back(p_hp_host_ref);
p_lp_host_refs.push_back(p_lp_host_ref);
p_sink_host_refs.push_back(p_sink_host_ref);
if(p_drop > 0)
{
randval_host_refs.push_back(randval_host_ref);
@@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dbias_buf.SetZero();
if(sink_grad)
d_sink_buf.SetZero();
if(launcher.needs_zero_dq_acc)
dq_acc_buf.SetZero();
@@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_buf.FromDevice(dk_host.data());
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
if(sink_grad)
d_sink_buf.FromDevice(d_sink_host.data());
// Track the index into reference vectors (may differ from wb if batches were skipped)
ck_tile::index_t ref_idx = 0;
// validation sink accumulator: global over batch, shape [nhead]
ck_tile::HostTensor<AccDataType> d_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
// When padding is enabled, use logical lengths instead of computing from padded
@@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
ds_hp_host_ref.mDesc.get_lengths()[1],
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
if(sink_grad)
{
// Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
// where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType d_sink_head_acc = 0;
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
ck_tile::type_convert<AccDataType>(
o_host_refs[ref_idx](i_h, i_q, o)) *
p_undrop;
}
d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
}
d_sink_host_ref(i_h) += d_sink_head_acc;
}
}
if(use_dbias)
{
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
@@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
ref_idx++;
}
if(pass && sink_grad)
{
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dsink_pass = ck_tile::check_err(d_sink_host,
d_sink_host_ref,
std::string("Error: SinkGrad Incorrect results!"),
rtol,
atol);
pass &= dsink_pass;
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}

View File

@@ -39,7 +39,6 @@ function print_log_header(){
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"

View File

@@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
done
# sink gradient tests: same coverage as main tests but with -sink_grad=1
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "a" ; do
for p_drop in 0.0 0.2 ; do
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
done
done
done
done
done
# sink gradient additional cases: non-standard hdim
for hdim in 40 48 72 96 ; do
test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
set +x
new_fails_count=0

View File

@@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() {
done
}
# Sink-specific mask pattern tests (sliding window + sink token).
run_sink_mask_tests() {
# window_size[2,0], sink_size=2 (top-left causal + sink)
# before: after:
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
# window_size[0,3], sink_size=2 (top-left + sink)
# before: after:
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
# window_size[1,0], sink_size=2 (bottom-right + sink)
# before: after:
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
# window_size[2,0], sink_size=2 (bottom-right, group mode + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
# window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
}
# init_sink tests: validate sink token initialization across prec/hdim/mode.
run_sink_init_tests() {
for prec in "fp16" "bf16" ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for mask in 0 1 ; do
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
done
done
done
done
}
set -x
run_fp16_bf16_tests
@@ -242,6 +300,8 @@ run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8bf16_tests
run_fp8fp32_tests
run_sink_mask_tests
run_sink_init_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests

View File

@@ -1,93 +0,0 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
set -x
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
# window_size[2,0], sink_size = 2
# x=1/y=3
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
# l=2/r=0(tl) l=2/r=0/s=2(tl)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
# x=4/y=1
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
# l=0/r=3(tl) l=0/r=3/s=2(tl)
# l=3/r=0(br) l=3/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
# x=4/y=-1
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
# l=1/r=0(br) l=1/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
# x=-1/y=5
# * * * * * * * * * * * *
# * * * * * * * * * * * *
# 1 * * * * * 1 * * * * *
# 1 1 * * * * 1 1 * * * *
# 1 1 1 * * * ----> 1 1 1 * * *
# * 1 1 1 * * 1 1 1 1 * *
# * * 1 1 1 * 1 1 1 1 1 *
# * * * 1 1 1 1 1 * 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
# x=-1/y=8
# * * * * * * * * * *
# * * * * * * * * * *
# 1 * * * * ----> 1 * * * *
# 1 1 * * * 1 1 * * *
# 1 1 1 * * 1 1 1 * *
# 1 1 1 1 * 1 1 1 1 *
# 1 1 1 1 1 1 1 1 1 1
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1