mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Adding SWA decode dispatcher to support GPT-OSS shape + update smoke test
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
|
||||
// IsMasking=true, IsLocal=true. Targets GPT-OSS short-prefill SWA shapes
|
||||
// (max_seqlen_q in [257,1024], page_blk_size=32, GQA-8).
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16,
|
||||
/*IsMasking=*/true,
|
||||
/*HeadSize=*/64,
|
||||
/*BlockM=*/128,
|
||||
/*NumQPerKV=*/8,
|
||||
/*BlockSize=*/32,
|
||||
/*IsLocal=*/true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
|
||||
// IsMasking=true, IsLocal=true. Targets GPT-OSS decode shapes
|
||||
// (q=1, page_blk_size=32, GQA-8) with sliding-window-attention.
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16,
|
||||
/*IsMasking=*/true,
|
||||
/*HeadSize=*/64,
|
||||
/*BlockM=*/32,
|
||||
/*NumQPerKV=*/8,
|
||||
/*BlockSize=*/32,
|
||||
/*IsLocal=*/true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// d64 GQA-8 medium decode tier with BlockSize=32 (kBlockM=128, kBlockQ=16),
|
||||
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance.
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16,
|
||||
/*IsMasking=*/true,
|
||||
/*HeadSize=*/64,
|
||||
/*BlockM=*/128,
|
||||
/*NumQPerKV=*/8,
|
||||
/*BlockSize=*/32,
|
||||
/*IsLocal=*/true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// d64 GQA-8 tiny+bs32 decode tier (kBlockM=32, kBlockQ=4, BlockSize=32),
|
||||
// IsMasking=true, IsLocal=true. fp16 sibling of the bf16 instance used by
|
||||
// the GPT-OSS decode SWA path.
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16,
|
||||
/*IsMasking=*/true,
|
||||
/*HeadSize=*/64,
|
||||
/*BlockM=*/32,
|
||||
/*NumQPerKV=*/8,
|
||||
/*BlockSize=*/32,
|
||||
/*IsLocal=*/true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user