From f438cef2861b569c5019d8d5f03ff537758b2aba Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Fri, 8 May 2026 09:30:56 +0000 Subject: [PATCH] Add smoke tests for SWA edge cases and performance gating --- .../script/edge_test_swa.sh | 91 ++++++++++++++ .../script/perf_test_swa.sh | 111 ++++++++++++++++++ .../script/smoke_test_swa.sh | 80 +++++-------- .../unified_attention.cpp | 32 +++-- 4 files changed, 254 insertions(+), 60 deletions(-) create mode 100755 example/ck_tile/42_unified_attention/script/edge_test_swa.sh create mode 100755 example/ck_tile/42_unified_attention/script/perf_test_swa.sh diff --git a/example/ck_tile/42_unified_attention/script/edge_test_swa.sh b/example/ck_tile/42_unified_attention/script/edge_test_swa.sh new file mode 100755 index 0000000000..221934e858 --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/edge_test_swa.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# edge_test_swa.sh - Numerical edge cases for SWA in CK-tile unified attention. +# +# Tests the corner cases that random-shape sweeps in smoke_test_swa.sh miss: +# - window=1 (every Q row attends to its own K position only) +# - window > seq_k with right=0 (degenerates to plain causal) +# - explicit b:0,0 (alternative spelling of diagonal-only) +# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q +# (today this is dispatched to the large-tier kernel via the +# `if(is_local) tier = tile_tier::large` hack; correctness should hold even +# though it's wasteful) +# +# Same convention as smoke_test_swa.sh: every test must pass, exit code is the +# number of failures. + +set -uo pipefail + +EXE_NAME=tile_example_unified_attention +EXE="${EXE:-$(find . -name $EXE_NAME -type f | head -n 1)}" +if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then + echo "ERROR: $EXE_NAME not found. Set EXE=/path/to/$EXE_NAME or run from build dir." >&2 + exit 2 +fi + +# Same deterministic fixture as smoke_test_swa.sh. +COMMON="-prec=bf16 -seed=17 -verify=1 -warmup=0 -repeat=1 -varlen=0 -nb=1024 -page_blk_size=128" + +BASELINE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=512 -s_k=512 -query_lens=128,128,128,128 -kv_lens=128,128,128,128" +BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512,128 -kv_lens=400,256,512,128" + +# Decode-shape fixtures (q=1, kv=512). SWA dispatcher forces tile_tier::large so +# these go through the prefill kernel even though they're decode shapes — the +# point of these tests is that the large-tier IsLocal=true kernel still produces +# correct numerics on a single-token query. +DECODE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512" +DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512" + +TESTS=( + # Edge 1: window=1 — diagonal-only attention. Smallest non-zero window. + # xb:1 decodes to left=0, right=0 via window/2 split. + "baseA xb:1 |$BASELINE_A -mask=xb:1" + "baseB xb:1 |$BASELINE_B -mask=xb:1" + + # Edge 2: window > seq_k with right=0 — IsLocal=true kernel must still + # produce the same answer as the IsLocal=false causal kernel + # (verified independently against the reference). + "baseA b:8192,0 |$BASELINE_A -mask=b:8192,0" + "baseB b:8192,0 |$BASELINE_B -mask=b:8192,0" + + # Edge 3: alternative diagonal-only spelling via explicit b:0,0. + # Must produce identical numerics to xb:1 above. + "baseA b:0,0 |$BASELINE_A -mask=b:0,0" + "baseB b:0,0 |$BASELINE_B -mask=b:0,0" + + # Edge 4: decode shapes (single-token query). The SWA mask trims the K range + # to a 64-wide window at the bottom-right corner of the (1, 512) + # attention matrix, so most of the kv tail is masked out. + "decode q=1 d128 xb:64|$DECODE_A -mask=xb:64" + "decode q=1 d64 xb:64|$DECODE_B -mask=xb:64" +) + +n_pass=0 +n_fail=0 + +for entry in "${TESTS[@]}"; do + name="${entry%%|*}" + args="${entry#*|}" + + printf '== %-26s :: %s\n' "$name" "$args" + set +e + "$EXE" $COMMON $args > /tmp/swa_edge_out.$$ 2>&1 + ret=$? + set -e + + if [ $ret -eq 0 ]; then + echo " PASS" + n_pass=$((n_pass + 1)) + else + echo " FAIL (rc=$ret). Tail of output:" + tail -3 /tmp/swa_edge_out.$$ | sed 's/^/ /' + n_fail=$((n_fail + 1)) + fi + rm -f /tmp/swa_edge_out.$$ +done + +echo +echo "Summary:" +printf ' passed : %d\n' $n_pass +printf ' failed : %d\n' $n_fail + +exit $n_fail diff --git a/example/ck_tile/42_unified_attention/script/perf_test_swa.sh b/example/ck_tile/42_unified_attention/script/perf_test_swa.sh new file mode 100755 index 0000000000..03c4557270 --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/perf_test_swa.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# perf_test_swa.sh - Perf gating test for SWA in CK-tile unified attention. +# +# Asserts that the SWA KV-block iteration clip (the unified_attention_kernel.hpp +# "Sliding-window-attention: tighten the KV-block iteration..." block) is +# actually firing and skipping out-of-window KV blocks. +# +# Strategy: on a long-context prefill shape (kv=8192, q=128) with a small SWA +# window (128), the kernel should iterate ~3 KV sub-blocks per Q-tile instead +# of ~128 for plain causal. We assert >= MIN_SPEEDUP wall-clock speedup. +# +# Measured speedup on MI350 today (gfx950) is 12-20x; we assert 5x to leave +# generous headroom for other GPUs / contention while still catching a Step D +# regression that re-iterates the full KV. +# +# Run with: +# ./perf_test_swa.sh +# +# Exit code: +# 0 = SWA met the speedup threshold for both shape families. +# 1 = SWA failed to meet the threshold somewhere. +# 2 = environment error (binary not found, parse failure, etc.) + +set -uo pipefail + +EXE_NAME=tile_example_unified_attention +EXE="${EXE:-$(find . -name $EXE_NAME -type f | head -n 1)}" +if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then + echo "ERROR: $EXE_NAME not found. Set EXE=/path/to/$EXE_NAME or run from build dir." >&2 + exit 2 +fi + +# Speedup threshold; the actual measurement is 10-20x, so 5x is a generous +# regression guard. +MIN_SPEEDUP="${MIN_SPEEDUP:-5.0}" + +# verify=0 because we trust smoke_test_swa.sh / edge_test_swa.sh for numerics +# and want time_kernel_=true to actually measure clock time, not host-ref work. +COMMON="-prec=bf16 -seed=17 -verify=0 -warmup=5 -repeat=20 -varlen=0 -nb=1024 -page_blk_size=128" + +# Long-context prefill: kv=8192 with a 128-token query. This is the regime +# where SWA (window=128) is most lopsided vs causal (full lower triangle). +SHAPE_A="-d=128 -h_k=8 -nqpkv=1 -b=2 -s=128 -s_k=8192 -query_lens=128,128 -kv_lens=8192,8192" +SHAPE_B="-d=64 -h_k=1 -nqpkv=8 -b=2 -s=128 -s_k=8192 -query_lens=128,128 -kv_lens=8192,8192" + +# Parse "<...>, ms, <...>" out of the kernel summary line. +# We grep -oP a single match to avoid partial reads of a large output. +extract_ms() { + grep -oP '\d+\.\d+(?= ms,)' | head -n 1 +} + +run_one() { + local label="$1"; shift + local out + out=$("$EXE" $COMMON "$@" 2>&1) + local ms + ms=$(echo "$out" | extract_ms) + if [ -z "$ms" ]; then + echo "ERROR: failed to extract ms from output of '$label'" >&2 + echo "$out" | tail -10 >&2 + exit 2 + fi + printf '%s\n' "$ms" +} + +# Returns 0 if "$1 / $2 >= $MIN_SPEEDUP", 1 otherwise. +check_speedup() { + awk -v c="$1" -v s="$2" -v m="$MIN_SPEEDUP" \ + 'BEGIN { sp = c / s; if (sp >= m) exit 0; else exit 1 }' +} + +n_fail=0 +overall_status=0 + +run_one_shape() { + local shape_name="$1" + local shape_args="$2" + + echo "=== $shape_name ===" + local t_causal t_swa + t_causal=$(run_one "$shape_name causal" $shape_args -mask=b) + t_swa=$(run_one "$shape_name swa" $shape_args -mask=xb:128) + + local speedup + speedup=$(awk -v c="$t_causal" -v s="$t_swa" 'BEGIN { printf "%.2f", c / s }') + + printf ' causal : %8s ms\n' "$t_causal" + printf ' swa xb:128 : %8s ms\n' "$t_swa" + printf ' speedup : %sx (threshold %sx)\n' "$speedup" "$MIN_SPEEDUP" + + if check_speedup "$t_causal" "$t_swa"; then + echo " PASS" + else + echo " FAIL: SWA was not >= ${MIN_SPEEDUP}x faster than causal." + echo " Most likely culprit: Step D (KV-block iteration clip in" + echo " unified_attention_kernel.hpp) was disabled or regressed," + echo " leaving the SWA path iterating the full KV like causal." + n_fail=$((n_fail + 1)) + overall_status=1 + fi + echo +} + +run_one_shape "d=128 MHA, q=128, kv=8192" "$SHAPE_A" +run_one_shape "d=64 GQA-8 (h_k=1), q=128, kv=8192" "$SHAPE_B" + +if [ $overall_status -eq 0 ]; then + echo "All perf gates passed." +fi + +exit $overall_status diff --git a/example/ck_tile/42_unified_attention/script/smoke_test_swa.sh b/example/ck_tile/42_unified_attention/script/smoke_test_swa.sh index dca20de78a..12d05f411f 100755 --- a/example/ck_tile/42_unified_attention/script/smoke_test_swa.sh +++ b/example/ck_tile/42_unified_attention/script/smoke_test_swa.sh @@ -1,20 +1,15 @@ #!/bin/bash -# smoke_test_swa.sh - RED tests for Sliding Window Attention (SWA) -# in the CK-tile unified attention kernel. +# smoke_test_swa.sh - Sliding Window Attention (SWA) smoke tests for the +# CK-tile unified attention kernel. # -# Each test entry is "EXPECT|EXTRA_ARGS" where EXPECT is GREEN or RED. -# GREEN: the test must currently pass; failing it is a regression. -# RED: the test must currently fail; passing it means SWA support landed -# and the test should be moved to GREEN. +# Each test entry is "NAME|EXTRA_ARGS"; every test must pass against the host +# reference. Failure exit code is the number of failed tests. # # Run with: # ./smoke_test_swa.sh -# -# Exit code is the number of unexpected outcomes (0 = all matched expectation). set -uo pipefail -SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) EXE_NAME=tile_example_unified_attention EXE="${EXE:-$(find . -name $EXE_NAME -type f | head -n 1)}" if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then @@ -34,67 +29,48 @@ BASELINE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=512 -s_k=512 -query_lens=128,128,128, BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512,128 -kv_lens=400,256,512,128" TESTS=( - # Causal regression guards (must pass). - "GREEN|baseA causal |$BASELINE_A -mask=b" - "GREEN|baseB causal |$BASELINE_B -mask=b" + # Causal regression guards. + "baseA causal |$BASELINE_A -mask=b" + "baseB causal |$BASELINE_B -mask=b" - # SWA via xformer-style window: kernel is now expected to honor the SWA - # window on both axes (per-pixel mask + KV-block iteration clip). - "GREEN|baseA xb:64 |$BASELINE_A -mask=xb:64" - "GREEN|baseA xb:128 |$BASELINE_A -mask=xb:128" - "GREEN|baseB xb:64 |$BASELINE_B -mask=xb:64" - "GREEN|baseB xb:128 |$BASELINE_B -mask=xb:128" + # SWA via xformer-style window: per-pixel mask + KV-block iteration clip. + "baseA xb:64 |$BASELINE_A -mask=xb:64" + "baseA xb:128 |$BASELINE_A -mask=xb:128" + "baseB xb:64 |$BASELINE_B -mask=xb:64" + "baseB xb:128 |$BASELINE_B -mask=xb:128" # SWA via FA-style explicit left/right window. - "GREEN|baseA b:64,0 |$BASELINE_A -mask=b:64,0" - "GREEN|baseB b:64,0 |$BASELINE_B -mask=b:64,0" + "baseA b:64,0 |$BASELINE_A -mask=b:64,0" + "baseB b:64,0 |$BASELINE_B -mask=b:64,0" ) -n_green_pass=0 -n_green_fail=0 # regressions -n_red_pass=0 # unexpected SWA passes (move to GREEN) -n_red_fail=0 # expected RED +n_pass=0 +n_fail=0 for entry in "${TESTS[@]}"; do - expect="${entry%%|*}" - expect="${expect// /}" - rest="${entry#*|}" - name="${rest%%|*}" - args="${rest#*|}" + name="${entry%%|*}" + args="${entry#*|}" - printf '== [%-5s] %-22s :: %s\n' "$expect" "$name" "$args" + printf '== %-22s :: %s\n' "$name" "$args" set +e "$EXE" $COMMON $args > /tmp/swa_test_out.$$ 2>&1 ret=$? set -e - if [ "$expect" = "GREEN" ]; then - if [ $ret -eq 0 ]; then - echo " PASS (as expected)" - n_green_pass=$((n_green_pass + 1)) - else - echo " REGRESSION: expected GREEN but failed (rc=$ret). Tail of output:" - tail -3 /tmp/swa_test_out.$$ | sed 's/^/ /' - n_green_fail=$((n_green_fail + 1)) - fi + if [ $ret -eq 0 ]; then + echo " PASS" + n_pass=$((n_pass + 1)) else - if [ $ret -ne 0 ]; then - echo " FAIL (RED, as expected)" - n_red_fail=$((n_red_fail + 1)) - else - echo " UNEXPECTED PASS: SWA support may have landed. Move this test to GREEN." - n_red_pass=$((n_red_pass + 1)) - fi + echo " FAIL (rc=$ret). Tail of output:" + tail -3 /tmp/swa_test_out.$$ | sed 's/^/ /' + n_fail=$((n_fail + 1)) fi rm -f /tmp/swa_test_out.$$ done echo echo "Summary:" -printf ' GREEN passed (good) : %d\n' $n_green_pass -printf ' GREEN failed (REGRESSION) : %d\n' $n_green_fail -printf ' RED failed (expected today) : %d\n' $n_red_fail -printf ' RED passed (flip to GREEN now) : %d\n' $n_red_pass +printf ' passed : %d\n' $n_pass +printf ' failed : %d\n' $n_fail -# Exit code = number of unexpected outcomes. -exit $((n_green_fail + n_red_pass)) +exit $n_fail diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 47413b4fed..91056e7c0d 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -99,16 +99,32 @@ std::pair unified_attention(const unified_attention_args& args, const stream_config& config) { const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); - // SWA is only when masking AND at least one window edge is finite. Causal - // (left=-1, right=0) keeps is_local=false and uses the existing instances. + // Real SWA = "at least one non-trivial window edge". Plain causal lives at + // (left=-1, right=0); without this guard it would hit the IsLocal=true path + // and fail for shape tiers where we have not (yet) instantiated local kernels. + // left >= 0 : finite look-back (e.g. causal SWA, dense SWA, diagonal-only) + // right > 0 : finite look-ahead (bidirectional SWA, anti-causal SWA) + // Note "right >= 0" would mis-classify plain causal (right=0) as SWA. const bool is_local = - is_mask && (args.window_size_left >= 0 || args.window_size_right >= 0); + is_mask && (args.window_size_left >= 0 || args.window_size_right > 0); auto tier = select_tile_tier(args); - // For now SWA instances only exist at the large prefill tier (the dispatcher's - // final `else` branch — 8 warps, kBlockM=256). Forcing the largest tier for - // SWA keeps dispatch correct without proliferating instance combinations; - // perf for SWA-on-decode-shapes can be revisited later. - if(is_local) tier = tile_tier::large; + // SWA instances currently only exist at the large prefill tier (kBlockM=256, + // 8 warps). Each requires args.page_blk_size >= kBlockN of the instance — + // otherwise the kernel hits a device-side `kv_page_size_in_blocks >= 1` + // assertion. When SWA is requested on an unsupported (shape, page_blk_size) + // pair we return {false, 0} so the caller (e.g. _try_ck_unified_attention) + // can fall back to a backend that handles it (Triton). Falling through to + // the IsLocal=false path would silently ignore window_size_left and produce + // wrong outputs, so we reject explicitly. + if(is_local) + { + const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1); + const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8); + const index_t kBN_req = d128_mha ? 32 : (d64_gqa8 ? 64 : 0); + if(kBN_req == 0 || args.page_blk_size < kBN_req) + return {false, 0.f}; + tier = tile_tier::large; + } // d128, MHA (num_queries_per_kv == 1) if(args.hdim == 128 && args.num_queries_per_kv == 1)