Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-13 17:55:48 +00:00
Adding SWA decode dispatcher to support GPT-OSS shape + update smoke test
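The dispatcher change below amounts to a shape/page-size routing table for sliding-window attention (SWA). A minimal standalone sketch of that routing, for orientation only: the struct, enum, and function names here are illustrative (not the library API); the field names, shapes, and tier choices are taken from the diff.

// Standalone sketch of the SWA (IsLocal=true) routing added by this commit.
// Illustrative only: the real dispatcher operates on unified_attention_args and
// signals "no instance" by returning {false, 0.f} so the caller (e.g.
// _try_ck_unified_attention) can fall back to another backend such as Triton.
#include <optional>

enum class tile_tier { tiny, small, medium, large };

struct swa_shape            // hypothetical stand-in for the relevant args fields
{
    int hdim;               // head size (64 or 128)
    int num_queries_per_kv; // GQA factor (1 = MHA, 8 = GQA-8)
    int page_blk_size;      // KV page block size in tokens (e.g. 32 or 64)
};

// Returns the tier an IsLocal=true request is routed to, or std::nullopt when no
// SWA instance exists for the shape and the caller should fall back.
std::optional<tile_tier> route_swa(const swa_shape& s, tile_tier selected)
{
    const bool d128_mha = (s.hdim == 128 && s.num_queries_per_kv == 1);
    const bool d64_gqa8 = (s.hdim == 64 && s.num_queries_per_kv == 8);

    if(d128_mha)
    {
        if(s.page_blk_size < 32) return std::nullopt;
        return tile_tier::large;                            // d128_*_mask_local
    }
    if(d64_gqa8)
    {
        if(s.page_blk_size >= 64) return tile_tier::large;  // d64_*_mask_gqa8_local
        if(s.page_blk_size >= 32)                           // GPT-OSS path (bs32)
        {
            if(selected == tile_tier::tiny || selected == tile_tier::medium)
                return selected;                            // new decode-tier local instances
            return std::nullopt;                            // small+bs32 SWA: no instance yet
        }
        return std::nullopt;
    }
    return std::nullopt;                                    // unsupported shape
}

The hunks below wire this up: four new IsLocal=true kernel trait instances (bf16/fp16, tiny and medium, all BlockSize=32), two *_LOCAL dispatch macros, the updated routing in unified_attention, and new Edge 5 smoke tests that pin the GPT-OSS window (left=127, right=0) to the bs32 decode-tier instances.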
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
// IsMasking=true, IsLocal=true. Targets GPT-OSS short-prefill SWA shapes
// (max_seqlen_q in [257,1024], page_blk_size=32, GQA-8).
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16,
                                           /*IsMasking=*/true,
                                           /*HeadSize=*/64,
                                           /*BlockM=*/128,
                                           /*NumQPerKV=*/8,
                                           /*BlockSize=*/32,
                                           /*IsLocal=*/true>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
// IsMasking=true, IsLocal=true. Targets GPT-OSS decode shapes
// (q=1, page_blk_size=32, GQA-8) with sliding-window-attention.
using kernel_traits =
    unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16,
                                                /*IsMasking=*/true,
                                                /*HeadSize=*/64,
                                                /*BlockM=*/32,
                                                /*NumQPerKV=*/8,
                                                /*BlockSize=*/32,
                                                /*IsLocal=*/true>;

INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

} // namespace ck_tile
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance.
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16,
                                           /*IsMasking=*/true,
                                           /*HeadSize=*/64,
                                           /*BlockM=*/128,
                                           /*NumQPerKV=*/8,
                                           /*BlockSize=*/32,
                                           /*IsLocal=*/true>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance used by
// the GPT-OSS decode SWA path.
using kernel_traits =
    unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16,
                                                /*IsMasking=*/true,
                                                /*HeadSize=*/64,
                                                /*BlockM=*/32,
                                                /*NumQPerKV=*/8,
                                                /*BlockSize=*/32,
                                                /*IsLocal=*/true>;

INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

} // namespace ck_tile
@@ -5,10 +5,11 @@
# - window=1 (every Q row attends to its own K position only)
# - window > seq_k with right=0 (degenerates to plain causal)
# - explicit b:0,0 (alternative spelling of diagonal-only)
# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q
#   (today this is dispatched to the large-tier kernel via the
#   `if(is_local) tier = tile_tier::large` hack; correctness should hold even
#   though it's wasteful)
# - decode shapes (q=1, kv>>1) — exercises the SWA path on a single-token Q.
#   For page_blk_size>=64 (Edge 4) we route to the large-tier kernel, which is
#   wasteful but correct. For page_blk_size==32 (Edge 5) we route to the
#   decode-tier IsLocal=true instances added in Phase 5 — that's the GPT-OSS
#   production path.
#
# Same convention as smoke_test_swa.sh: every test must pass, exit code is the
# number of failures.
@@ -35,6 +36,15 @@ BASELINE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=512 -s_k=512 -query_lens=400,256,512,
DECODE_A="-d=128 -h_k=8 -nqpkv=1 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512"
DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512"

# Decode + page_blk_size=32 fixtures for d=64 GQA-8. These exercise the NEW
# decode-tier IsLocal=true instances added in Phase 5:
#   - DECODE_BS32_Q1 (q=1)             → tiny+bs32 local (kBlockM=32, kBlockQ=4)
#   - DECODE_BS32_QM (q in [256,1024]) → medium+bs32 local (kBlockM=128, kBlockQ=16)
# Use the GPT-OSS-shaped window (left=127, right=0) to mirror the production
# workload that motivated Phase 5.
DECODE_BS32_Q1="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens=512,512,512,512"
DECODE_BS32_QM="-d=64 -h_k=1 -nqpkv=8 -b=2 -s=1024 -s_k=1024 -query_lens=1024,512 -kv_lens=1024,512"

TESTS=(
    # Edge 1: window=1 — diagonal-only attention. Smallest non-zero window.
    #         xb:1 decodes to left=0, right=0 via window/2 split.
@@ -55,8 +65,17 @@ TESTS=(
    # Edge 4: decode shapes (single-token query). The SWA mask trims the K range
    #         to a 64-wide window at the bottom-right corner of the (1, 512)
    #         attention matrix, so most of the kv tail is masked out.
    "decode q=1 d128 xb:64|$DECODE_A -mask=xb:64"
    "decode q=1 d64 xb:64|$DECODE_B -mask=xb:64"
    "decode q=1 d128 xb:64 |$DECODE_A -mask=xb:64"
    "decode q=1 d64 xb:64 |$DECODE_B -mask=xb:64"

    # Edge 5 (Phase 5): GPT-OSS-shaped d64 GQA-8 SWA on page_blk_size=32. These
    # MUST hit the new decode-tier IsLocal=true instances; if a regression takes
    # them back to the bs64-only fallback they fail with "no matching kernel
    # instance" or wrong numerics.
    "decode q=1 d64 bs32 b:127,0 |$DECODE_BS32_Q1 -page_blk_size=32 -mask=b:127,0"
    "decode q=1 d64 bs32 xb:128 |$DECODE_BS32_Q1 -page_blk_size=32 -mask=xb:128"
    "shortpf d64 bs32 b:127,0 |$DECODE_BS32_QM -page_blk_size=32 -mask=b:127,0"
    "shortpf d64 bs32 xb:128 |$DECODE_BS32_QM -page_blk_size=32 -mask=xb:128"
)

n_pass=0
@@ -66,7 +85,7 @@ for entry in "${TESTS[@]}"; do
    name="${entry%%|*}"
    args="${entry#*|}"

    printf '== %-26s :: %s\n' "$name" "$args"
    printf '== %-32s :: %s\n' "$name" "$args"
    set +e
    "$EXE" $COMMON $args > /tmp/swa_edge_out.$$ 2>&1
    ret=$?
@@ -34,6 +34,21 @@ std::ostream& operator<<(std::ostream& stream,
        return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
    }

// SWA-aware decode dispatchers for bs32. These mirror the non-local *_BS32 macros
// but flip IsLocal=true on the 7th template arg, so the kernel uses the
// sliding-window mask AND the per-Q-tile KV-block iteration clipping.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(DType, HSize, BM, NQPKV) \
    { \
        using kernel_traits = unified_attention_decode_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
        return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
    }

#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV) \
    { \
        using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
        return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
    }

// Dispatch macros for three tile tiers (default block_size).
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
    { \
@@ -108,22 +123,52 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
    const bool is_local =
        is_mask && (args.window_size_left >= 0 || args.window_size_right > 0);
    auto tier = select_tile_tier(args);
    // SWA instances currently only exist at the large prefill tier (kBlockM=256,
    // 8 warps). Each requires args.page_blk_size >= kBlockN of the instance —
    // otherwise the kernel hits a device-side `kv_page_size_in_blocks >= 1`
    // assertion. When SWA is requested on an unsupported (shape, page_blk_size)
    // pair we return {false, 0} so the caller (e.g. _try_ck_unified_attention)
    // can fall back to a backend that handles it (Triton). Falling through to
    // the IsLocal=false path would silently ignore window_size_left and produce
    // SWA-instance availability matrix (IsLocal=true). Anything not listed here
    // returns {false, 0} so the caller (e.g. _try_ck_unified_attention) falls
    // back to a backend that handles it (Triton). Falling through to the
    // IsLocal=false path would silently ignore window_size_left and produce
    // wrong outputs, so we reject explicitly.
    //
    // shape          | page_blk_size | tier we route to            | instance
    // ---------------+---------------+-----------------------------+---------------------
    // d128 MHA       | >= 32         | tile_tier::large            | d128_*_mask_local
    // d64 GQA-8      | >= 64         | tile_tier::large            | d64_*_mask_gqa8_local
    // d64 GQA-8      | == 32, tiny   | tile_tier::tiny             | d64_*_mask_gqa8_bs32_narrow_local
    // d64 GQA-8      | == 32, med    | tile_tier::medium           | d64_*_mask_gqa8_bs32_decode_local
    // (small+bs32 SWA has no instance yet -> Triton fallback; GPT-OSS shows zero
    // such shapes in practice. Bumping it to medium is plausible but wastes a
    // full kBlockM=128 tile when only kBlockM=64 is needed -- revisit if the
    // workload changes.)
    if(is_local)
    {
        const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1);
        const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8);
        const index_t kBN_req = d128_mha ? 32 : (d64_gqa8 ? 64 : 0);
        if(kBN_req == 0 || args.page_blk_size < kBN_req)
        const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1);
        const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8);
        if(d128_mha)
        {
            if(args.page_blk_size < 32) return {false, 0.f};
            tier = tile_tier::large;
        }
        else if(d64_gqa8)
        {
            if(args.page_blk_size >= 64)
            {
                tier = tile_tier::large;
            }
            else if(args.page_blk_size >= 32)
            {
                if(tier != tile_tier::tiny && tier != tile_tier::medium)
                    return {false, 0.f};
                // keep selected tier as-is
            }
            else
            {
                return {false, 0.f};
            }
        }
        else
        {
            return {false, 0.f};
            tier = tile_tier::large;
        }
    }

    // d128, MHA (num_queries_per_kv == 1)
@@ -155,13 +200,15 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
            // Avoids 1-warp race condition; 2x less waste than small tier.
            if(args.data_type == unified_attention_args::data_type_enum::fp16)
            {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 32, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
            }
            else if(args.data_type == unified_attention_args::data_type_enum::bf16)
            {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 32, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
            }
        } else {
            // bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
@@ -205,8 +252,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
        if(args.data_type == unified_attention_args::data_type_enum::fp16)
        {
            if(use_bs32) {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
            } else {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
@@ -215,8 +263,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
        else if(args.data_type == unified_attention_args::data_type_enum::bf16)
        {
            if(use_bs32) {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
                else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
            } else {
                if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
                else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
@@ -248,6 +297,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
    return std::make_pair(false, -1.f);
}

#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32