Adding SWA decode dispatcher to support GPT-OSS shape + update smoke test

This commit is contained in:
Damien Lejeune
2026-05-08 14:38:16 +00:00
parent e36693c4dc
commit b686143624
6 changed files with 188 additions and 27 deletions

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
// IsMasking=true, IsLocal=true. Targets GPT-OSS short-prefill SWA shapes
// (max_seqlen_q in [257,1024], page_blk_size=32, GQA-8).
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16,
/*IsMasking=*/true,
/*HeadSize=*/64,
/*BlockM=*/128,
/*NumQPerKV=*/8,
/*BlockSize=*/32,
/*IsLocal=*/true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
// IsMasking=true, IsLocal=true. Targets GPT-OSS decode shapes
// (q=1, page_blk_size=32, GQA-8) with sliding-window-attention.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16,
/*IsMasking=*/true,
/*HeadSize=*/64,
/*BlockM=*/32,
/*NumQPerKV=*/8,
/*BlockSize=*/32,
/*IsLocal=*/true>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance.
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16,
/*IsMasking=*/true,
/*HeadSize=*/64,
/*BlockM=*/128,
/*NumQPerKV=*/8,
/*BlockSize=*/32,
/*IsLocal=*/true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance used by
// the GPT-OSS decode SWA path.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16,
/*IsMasking=*/true,
/*HeadSize=*/64,
/*BlockM=*/32,
/*NumQPerKV=*/8,
/*BlockSize=*/32,
/*IsLocal=*/true>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile