Add smoke tests for SWA edge cases and performance gating

This commit is contained in:
Damien Lejeune
2026-05-08 09:30:56 +00:00
parent 5afd97ff5b
commit f438cef286
4 changed files with 254 additions and 60 deletions

View File

@@ -0,0 +1,91 @@
#!/bin/bash
# edge_test_swa.sh - Numerical edge cases for SWA in CK-tile unified attention.
#
# Tests the corner cases that random-shape sweeps in smoke_test_swa.sh miss:
# - window=1 (every Q row attends to its own K position only)
# - window > seq_k with right=0 (degenerates to plain causal)
# - explicit b:0,0 (alternative spelling of diagonal-only)
# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q
# (today this is dispatched to the large-tier kernel via the
# `if(is_local) tier = tile_tier::large` hack; correctness should hold even
# though it's wasteful)
#
# Same convention as smoke_test_swa.sh: every test must pass, exit code is the
# number of failures.
set -uo pipefail

EXE_NAME=tile_example_unified_attention

# A caller-supplied, non-empty EXE wins. Otherwise fall back to the first
# regular file with the expected name found below the current directory.
if [[ -z "${EXE:-}" ]]; then
    EXE=$(find . -name "$EXE_NAME" -type f | head -n 1)
fi

# Exit status 2 is reserved for "binary missing or not executable".
if [[ -z "$EXE" || ! -x "$EXE" ]]; then
    echo "ERROR: $EXE_NAME not found. Set EXE=/path/to/$EXE_NAME or run from build dir." >&2
    exit 2
fi
# Same deterministic fixture as smoke_test_swa.sh. These flags are passed to
# every test invocation, ahead of the per-test arguments.
COMMON="-prec=bf16 -seed=17 -verify=1 -warmup=0 -repeat=1 -varlen=0 -nb=1024 -page_blk_size=128"
# Prefill-shape fixtures shared by edge cases 1-3 below.
BASELINE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=512 -s_k=512 -query_lens=128,128,128,128 -kv_lens=128,128,128,128"
BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512,128 -kv_lens=400,256,512,128"
# Decode-shape fixtures (q=1, kv=512). SWA dispatcher forces tile_tier::large so
# these go through the prefill kernel even though they're decode shapes — the
# point of these tests is that the large-tier IsLocal=true kernel still produces
# correct numerics on a single-token query.
DECODE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512"
DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512"
# Test table. Each entry is "NAME|EXTRA_ARGS": the text before the first '|' is
# the display name, the text after it is appended to $COMMON on the command
# line of the example binary.
TESTS=(
# Edge 1: window=1 — diagonal-only attention. Smallest non-zero window.
# xb:1 decodes to left=0, right=0 via window/2 split.
"baseA xb:1 |$BASELINE_A -mask=xb:1"
"baseB xb:1 |$BASELINE_B -mask=xb:1"
# Edge 2: window > seq_k with right=0 — IsLocal=true kernel must still
# produce the same answer as the IsLocal=false causal kernel
# (verified independently against the reference).
"baseA b:8192,0 |$BASELINE_A -mask=b:8192,0"
"baseB b:8192,0 |$BASELINE_B -mask=b:8192,0"
# Edge 3: alternative diagonal-only spelling via explicit b:0,0.
# Expected to be numerically equivalent to xb:1 above. Note that this
# script does not compare the two outputs with each other; each case is
# judged only by the binary's exit status (run with -verify=1).
"baseA b:0,0 |$BASELINE_A -mask=b:0,0"
"baseB b:0,0 |$BASELINE_B -mask=b:0,0"
# Edge 4: decode shapes (single-token query). The SWA mask trims the K range
# to a window anchored at the bottom-right corner of the (1, 512)
# attention matrix, so everything outside that window — most of the KV
# range — is masked out.
"decode q=1 d128 xb:64|$DECODE_A -mask=xb:64"
"decode q=1 d64 xb:64|$DECODE_B -mask=xb:64"
)
# Run every TESTS entry through the example binary, tally the results and exit
# with the number of failed tests.
n_pass=0
n_fail=0

# Scratch file for the binary's combined stdout/stderr. mktemp gives a private,
# unpredictable name (a fixed /tmp/<name>.$$ path can be pre-created by another
# user on a shared machine), and the EXIT trap removes it even when the script
# is interrupted part-way through.
out_file=$(mktemp "${TMPDIR:-/tmp}/swa_edge_out.XXXXXX") || {
    echo "ERROR: could not create a temporary output file." >&2
    exit 2
}
trap 'rm -f "$out_file"' EXIT

for entry in "${TESTS[@]}"; do
    # Entry format is "NAME|EXTRA_ARGS"; split on the first '|'.
    name="${entry%%|*}"
    args="${entry#*|}"
    printf '== %-26s :: %s\n' "$name" "$args"

    # $COMMON and $args are deliberately unquoted: each holds several
    # whitespace-separated flags that must be word-split.
    # The exit status is captured with `|| ret=$?` instead of toggling
    # `set +e` / `set -e`, which used to switch errexit ON for the remainder of
    # a script that never enabled it in the first place.
    ret=0
    "$EXE" $COMMON $args > "$out_file" 2>&1 || ret=$?

    if [ "$ret" -eq 0 ]; then
        echo " PASS"
        n_pass=$((n_pass + 1))
    else
        echo " FAIL (rc=$ret). Tail of output:"
        tail -3 "$out_file" | sed 's/^/ /'
        n_fail=$((n_fail + 1))
    fi
done

echo
echo "Summary:"
printf ' passed : %d\n' "$n_pass"
printf ' failed : %d\n' "$n_fail"
exit "$n_fail"

View File

@@ -0,0 +1,111 @@
#!/bin/bash
# perf_test_swa.sh - Perf gating test for SWA in CK-tile unified attention.
#
# Asserts that the SWA KV-block iteration clip (the unified_attention_kernel.hpp
# "Sliding-window-attention: tighten the KV-block iteration..." block) is
# actually firing and skipping out-of-window KV blocks.
#
# Strategy: on a long-context prefill shape (kv=8192, q=128) with a small SWA
# window (128), the kernel should iterate ~3 KV sub-blocks per Q-tile instead
# of ~128 for plain causal. We assert >= MIN_SPEEDUP wall-clock speedup.
#
# Measured speedup on MI350 today (gfx950) is 12-20x; we assert 5x to leave
# generous headroom for other GPUs / contention while still catching a
# regression in the KV-block iteration clip ("Step D") that makes the SWA path
# re-iterate the full KV.
#
# Run with:
# ./perf_test_swa.sh
#
# Exit code:
# 0 = SWA met the speedup threshold for both shape families.
# 1 = SWA failed to meet the threshold somewhere.
# 2 = environment error (binary not found, parse failure, etc.)
set -uo pipefail

EXE_NAME=tile_example_unified_attention

# A caller-supplied, non-empty EXE wins. Otherwise fall back to the first
# regular file with the expected name found below the current directory.
if [[ -z "${EXE:-}" ]]; then
    EXE=$(find . -name "$EXE_NAME" -type f | head -n 1)
fi
if [[ -z "$EXE" || ! -x "$EXE" ]]; then
    echo "ERROR: $EXE_NAME not found. Set EXE=/path/to/$EXE_NAME or run from build dir." >&2
    exit 2
fi

# Minimum causal/SWA time ratio required to pass; override via the environment.
# Measured speedups are roughly 10-20x, so 5x is a generous regression guard.
: "${MIN_SPEEDUP:=5.0}"

# -verify=0: numerics are covered by smoke_test_swa.sh / edge_test_swa.sh, so
# this script only cares about kernel time and skips the host-reference work.
COMMON="-prec=bf16 -seed=17 -verify=0 -warmup=5 -repeat=20 -varlen=0 -nb=1024 -page_blk_size=128"

# Long-context prefill: kv=8192 with a 128-token query. This is the regime
# where SWA (window=128) is most lopsided vs causal (full lower triangle).
SHAPE_A="-d=128 -h_k=8 -nqpkv=1 -b=2 -s=128 -s_k=8192 -query_lens=128,128 -kv_lens=8192,8192"
SHAPE_B="-d=64 -h_k=1 -nqpkv=8 -b=2 -s=128 -s_k=8192 -query_lens=128,128 -kv_lens=8192,8192"
# extract_ms: read the binary's output on stdin and print the first decimal
# number that is immediately followed by " ms," (the kernel summary line has
# the form "<...>, <ms> ms, <...>"). Prints nothing when no such number exists.
# Only the first match is kept so a large output is not read further than
# necessary.
extract_ms() {
    local ms_pattern='\d+\.\d+(?= ms,)'
    grep --only-matching --perl-regexp "$ms_pattern" | head -n 1
}
# run_one LABEL [ARGS...]
#   Runs "$EXE" with $COMMON followed by ARGS, capturing stdout and stderr, and
#   prints the kernel time in milliseconds (as found by extract_ms) on stdout.
#   If no time can be extracted, reports the problem plus the last 10 output
#   lines on stderr and exits with status 2. When this function is invoked
#   inside a command substitution, that exit terminates only the subshell.
run_one() {
    local label="$1"
    shift

    local out ms
    # $COMMON is deliberately unquoted: it holds several whitespace-separated
    # flags that must be word-split.
    out=$("$EXE" $COMMON "$@" 2>&1)
    ms=$(echo "$out" | extract_ms)

    if [ -n "$ms" ]; then
        printf '%s\n' "$ms"
        return
    fi

    echo "ERROR: failed to extract ms from output of '$label'" >&2
    echo "$out" | tail -10 >&2
    exit 2
}
# check_speedup CAUSAL_MS SWA_MS
#   Returns 0 if CAUSAL_MS / SWA_MS >= $MIN_SPEEDUP, 1 otherwise.
#   A zero, negative, empty or non-numeric SWA_MS also returns 1: dividing by it
#   would otherwise abort awk with a "division by zero" error and an unrelated
#   exit status.
check_speedup() {
    awk -v c="$1" -v s="$2" -v m="$MIN_SPEEDUP" \
        'BEGIN {
            if (s + 0 <= 0) exit 1
            sp = c / s
            if (sp >= m) exit 0; else exit 1
        }'
}
# Gate bookkeeping, updated by run_one_shape.
n_fail=0
overall_status=0

# run_one_shape SHAPE_NAME SHAPE_ARGS
#   Times the shape once with the plain causal mask (-mask=b) and once with a
#   sliding window (-mask=xb:128), prints both timings and their ratio, and
#   records a failure (n_fail, overall_status=1) when the ratio is below
#   $MIN_SPEEDUP.
#   Exits the script with status 2 when either timing cannot be obtained or the
#   SWA timing is not a positive number.
run_one_shape() {
    local shape_name="$1"
    local shape_args="$2"
    echo "=== $shape_name ==="

    local t_causal t_swa
    # run_one executes inside a command substitution, so its `exit 2` only ends
    # that subshell. Propagate the failure explicitly; otherwise an empty
    # timing would flow into awk and be reported as a perf FAIL (exit 1)
    # instead of the documented environment error (exit 2).
    # $shape_args is deliberately unquoted so its flags are word-split.
    t_causal=$(run_one "$shape_name causal" $shape_args -mask=b) || exit 2
    t_swa=$(run_one "$shape_name swa" $shape_args -mask=xb:128) || exit 2

    # A non-positive SWA time cannot be used as a divisor; treat it as a
    # measurement problem rather than a perf regression.
    if ! awk -v s="$t_swa" 'BEGIN { exit !(s + 0 > 0) }'; then
        echo "ERROR: non-positive SWA time '$t_swa' ms for '$shape_name'" >&2
        exit 2
    fi

    local speedup
    speedup=$(awk -v c="$t_causal" -v s="$t_swa" 'BEGIN { printf "%.2f", c / s }')
    printf ' causal : %8s ms\n' "$t_causal"
    printf ' swa xb:128 : %8s ms\n' "$t_swa"
    printf ' speedup : %sx (threshold %sx)\n' "$speedup" "$MIN_SPEEDUP"

    if check_speedup "$t_causal" "$t_swa"; then
        echo " PASS"
    else
        echo " FAIL: SWA was not >= ${MIN_SPEEDUP}x faster than causal."
        echo " Most likely culprit: Step D (KV-block iteration clip in"
        echo " unified_attention_kernel.hpp) was disabled or regressed,"
        echo " leaving the SWA path iterating the full KV like causal."
        n_fail=$((n_fail + 1))
        overall_status=1
    fi
    echo
}
# Gate both shape families, then exit with the accumulated status
# (0 = every gate passed, 1 = at least one gate failed).
run_one_shape "d=128 MHA, q=128, kv=8192" "$SHAPE_A"
run_one_shape "d=64 GQA-8 (h_k=1), q=128, kv=8192" "$SHAPE_B"

[ "$overall_status" -ne 0 ] || echo "All perf gates passed."
exit "$overall_status"

View File

@@ -1,20 +1,15 @@
#!/bin/bash
# smoke_test_swa.sh - RED tests for Sliding Window Attention (SWA)
# in the CK-tile unified attention kernel.
# smoke_test_swa.sh - Sliding Window Attention (SWA) smoke tests for the
# CK-tile unified attention kernel.
#
# Each test entry is "EXPECT|EXTRA_ARGS" where EXPECT is GREEN or RED.
# GREEN: the test must currently pass; failing it is a regression.
# RED: the test must currently fail; passing it means SWA support landed
# and the test should be moved to GREEN.
# Each test entry is "NAME|EXTRA_ARGS"; every test must pass against the host
# reference. Failure exit code is the number of failed tests.
#
# Run with:
# ./smoke_test_swa.sh
#
# Exit code is the number of unexpected outcomes (0 = all matched expectation).
set -uo pipefail
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXE_NAME=tile_example_unified_attention
EXE="${EXE:-$(find . -name $EXE_NAME -type f | head -n 1)}"
if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then
@@ -34,67 +29,48 @@ BASELINE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=512 -s_k=512 -query_lens=128,128,128,
BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512,128 -kv_lens=400,256,512,128"
TESTS=(
# Causal regression guards (must pass).
"GREEN|baseA causal |$BASELINE_A -mask=b"
"GREEN|baseB causal |$BASELINE_B -mask=b"
# Causal regression guards.
"baseA causal |$BASELINE_A -mask=b"
"baseB causal |$BASELINE_B -mask=b"
# SWA via xformer-style window: kernel is now expected to honor the SWA
# window on both axes (per-pixel mask + KV-block iteration clip).
"GREEN|baseA xb:64 |$BASELINE_A -mask=xb:64"
"GREEN|baseA xb:128 |$BASELINE_A -mask=xb:128"
"GREEN|baseB xb:64 |$BASELINE_B -mask=xb:64"
"GREEN|baseB xb:128 |$BASELINE_B -mask=xb:128"
# SWA via xformer-style window: per-pixel mask + KV-block iteration clip.
"baseA xb:64 |$BASELINE_A -mask=xb:64"
"baseA xb:128 |$BASELINE_A -mask=xb:128"
"baseB xb:64 |$BASELINE_B -mask=xb:64"
"baseB xb:128 |$BASELINE_B -mask=xb:128"
# SWA via FA-style explicit left/right window.
"GREEN|baseA b:64,0 |$BASELINE_A -mask=b:64,0"
"GREEN|baseB b:64,0 |$BASELINE_B -mask=b:64,0"
"baseA b:64,0 |$BASELINE_A -mask=b:64,0"
"baseB b:64,0 |$BASELINE_B -mask=b:64,0"
)
n_green_pass=0
n_green_fail=0 # regressions
n_red_pass=0 # unexpected SWA passes (move to GREEN)
n_red_fail=0 # expected RED
n_pass=0
n_fail=0
for entry in "${TESTS[@]}"; do
expect="${entry%%|*}"
expect="${expect// /}"
rest="${entry#*|}"
name="${rest%%|*}"
args="${rest#*|}"
name="${entry%%|*}"
args="${entry#*|}"
printf '== [%-5s] %-22s :: %s\n' "$expect" "$name" "$args"
printf '== %-22s :: %s\n' "$name" "$args"
set +e
"$EXE" $COMMON $args > /tmp/swa_test_out.$$ 2>&1
ret=$?
set -e
if [ "$expect" = "GREEN" ]; then
if [ $ret -eq 0 ]; then
echo " PASS (as expected)"
n_green_pass=$((n_green_pass + 1))
else
echo " REGRESSION: expected GREEN but failed (rc=$ret). Tail of output:"
tail -3 /tmp/swa_test_out.$$ | sed 's/^/ /'
n_green_fail=$((n_green_fail + 1))
fi
if [ $ret -eq 0 ]; then
echo " PASS"
n_pass=$((n_pass + 1))
else
if [ $ret -ne 0 ]; then
echo " FAIL (RED, as expected)"
n_red_fail=$((n_red_fail + 1))
else
echo " UNEXPECTED PASS: SWA support may have landed. Move this test to GREEN."
n_red_pass=$((n_red_pass + 1))
fi
echo " FAIL (rc=$ret). Tail of output:"
tail -3 /tmp/swa_test_out.$$ | sed 's/^/ /'
n_fail=$((n_fail + 1))
fi
rm -f /tmp/swa_test_out.$$
done
echo
echo "Summary:"
printf ' GREEN passed (good) : %d\n' $n_green_pass
printf ' GREEN failed (REGRESSION) : %d\n' $n_green_fail
printf ' RED failed (expected today) : %d\n' $n_red_fail
printf ' RED passed (flip to GREEN now) : %d\n' $n_red_pass
printf ' passed : %d\n' $n_pass
printf ' failed : %d\n' $n_fail
# Exit code = number of unexpected outcomes.
exit $((n_green_fail + n_red_pass))
exit $n_fail

View File

@@ -99,16 +99,32 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config)
{
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
// SWA is only when masking AND at least one window edge is finite. Causal
// (left=-1, right=0) keeps is_local=false and uses the existing instances.
// Real SWA = "at least one non-trivial window edge". Plain causal lives at
// (left=-1, right=0); without this guard it would hit the IsLocal=true path
// and fail for shape tiers where we have not (yet) instantiated local kernels.
// left >= 0 : finite look-back (e.g. causal SWA, dense SWA, diagonal-only)
// right > 0 : finite look-ahead (bidirectional SWA, anti-causal SWA)
// Note "right >= 0" would mis-classify plain causal (right=0) as SWA.
const bool is_local =
is_mask && (args.window_size_left >= 0 || args.window_size_right >= 0);
is_mask && (args.window_size_left >= 0 || args.window_size_right > 0);
auto tier = select_tile_tier(args);
// For now SWA instances only exist at the large prefill tier (the dispatcher's
// final `else` branch — 8 warps, kBlockM=256). Forcing the largest tier for
// SWA keeps dispatch correct without proliferating instance combinations;
// perf for SWA-on-decode-shapes can be revisited later.
if(is_local) tier = tile_tier::large;
// SWA instances currently only exist at the large prefill tier (kBlockM=256,
// 8 warps). Each requires args.page_blk_size >= kBlockN of the instance —
// otherwise the kernel hits a device-side `kv_page_size_in_blocks >= 1`
// assertion. When SWA is requested on an unsupported (shape, page_blk_size)
// pair we return {false, 0} so the caller (e.g. _try_ck_unified_attention)
// can fall back to a backend that handles it (Triton). Falling through to
// the IsLocal=false path would silently ignore window_size_left and produce
// wrong outputs, so we reject explicitly.
if(is_local)
{
const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1);
const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8);
const index_t kBN_req = d128_mha ? 32 : (d64_gqa8 ? 64 : 0);
if(kBN_req == 0 || args.page_blk_size < kBN_req)
return {false, 0.f};
tier = tile_tier::large;
}
// d128, MHA (num_queries_per_kv == 1)
if(args.hdim == 128 && args.num_queries_per_kv == 1)