Delete CK-UA bs32 variant family

The bs32 variants existed because pre-fix the pipeline required
kBlockN <= page_size, so page_size=32 forced a kBlockN=32 kernel
family. The multi-page-tile fix (commit 473869aba) lifted that
constraint and made kBlockN compile-time-independent of the runtime
page size, so the bs32 family is now redundant: every non-bs32 variant
is correct for any page_size.

This was validated in advance by forcing use_bs32=false in the
dispatcher and running the full correctness suite -- 236/240, identical
to baseline (the 4 remaining failures are the pre-existing int32-
overflow case, orthogonal).

Removes:
  * 16 instances/unified_attention_*_bs32_*.cpp files
  * unified_attention_decode_bs32_kernel_traits in unified_attention_impl.hpp
  * 3 _BS32 dispatch macros in unified_attention.cpp
  * 3 _p32 entries from the KernelVariant enum
  * 3 dispatch_*_p32 helper functions and their switch cases
  * the page_blk_size branch in select_config (now a pure tile-tier ladder)

Net: 16 fewer compile units (build time -6s on JIT), 78 fewer dispatcher
lines, and "which kernel runs?" is now driven purely by Q-tile shape.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-12 09:41:41 +00:00
parent fddb0d21cd
commit 5bd8f73a28
18 changed files with 18 additions and 390 deletions

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: decode kernel traits, bf16, masking enabled.
// Trailing args <64, 128, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE (mirrors the bs32 traits' parameter order) -- confirm against
// unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: small-tier decode traits, bf16, masking enabled.
// Trailing args <64, 64, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: tiny-tier decode traits, bf16, masking enabled.
// Trailing args <64, 16, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: bs32 decode traits (unified_attention_impl.hpp),
// bf16 with masking; HeadSize=64, kBlockM=32, NumQPerKV=8, BLOCK_SIZE=32.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: decode kernel traits, bf16, no masking.
// Trailing args <64, 128, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: small-tier decode traits, bf16, no masking.
// Trailing args <64, 64, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: tiny-tier decode traits, bf16, no masking.
// Trailing args <64, 16, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: bs32 decode traits (unified_attention_impl.hpp),
// bf16, no masking; HeadSize=64, kBlockM=32, NumQPerKV=8, BLOCK_SIZE=32.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: decode kernel traits, fp16, masking enabled.
// Trailing args <64, 128, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: small-tier decode traits, fp16, masking enabled.
// Trailing args <64, 64, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: tiny-tier decode traits, fp16, masking enabled.
// Trailing args <64, 16, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: bs32 decode traits (unified_attention_impl.hpp),
// fp16 with masking; HeadSize=64, kBlockM=32, NumQPerKV=8, BLOCK_SIZE=32.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: decode kernel traits, fp16, no masking.
// Trailing args <64, 128, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: small-tier decode traits, fp16, no masking.
// Trailing args <64, 64, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: tiny-tier decode traits, fp16, no masking.
// Trailing args <64, 16, 8, 32> are presumably HeadSize, kBlockM, NumQPerKV,
// BLOCK_SIZE -- confirm against unified_attention_impl.hpp.
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Explicit-instantiation TU: bs32 decode traits (unified_attention_impl.hpp),
// fp16, no masking; HeadSize=64, kBlockM=32, NumQPerKV=8, BLOCK_SIZE=32.
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream,
// 1. KernelVariant + select_config(args)
// - KernelVariant is a flat enum of every compiled kernel instance the
// module knows about. Each entry fixes the static knobs (kBlockM,
// warp count, MFMA shape, pipeline policy, optional kBlockN override).
// warp count, MFMA shape, pipeline policy).
// - select_config() is the ONLY place where shape-based runtime
// decisions live. It reads the problem (hdim, num_queries_per_kv,
// page_blk_size, avg_q, max_seqlen_q) and emits a KernelConfig.
// decisions live. It reads (hdim, num_queries_per_kv, avg_q,
// max_seqlen_q) and emits a KernelConfig.
//
// 2. dispatch_<variant>() helpers + the final switch
// - Each KernelVariant has a tiny helper that fans out over the
@@ -38,12 +38,10 @@ std::ostream& operator<<(std::ostream& stream,
// per-variant traits classes are unchanged from before; only the
// selection logic moved.
//
// Phase-1 note: page-size is currently still a static axis in the enum
// (the _p32 suffix marks the variant with kBlockN=32 that was originally
// required when page_size < 64). The multi-page-tile fix in the pipeline
// removed the underlying constraint, so a follow-up commit deletes the
// _p32 (a.k.a. "bs32") family entirely. Doing it in two steps keeps each
// diff easy to bisect against the test suite.
// page_size is intentionally NOT part of this enum. The multi-page-tile
// fix in the pipeline made the compile-time tile-N (kBlockN) independent
// of the runtime page_blk_size, so every variant is correct for any page
// size. Selection is driven purely by Q-tile shape.
// =============================================================================
enum class KernelVariant {
@@ -54,10 +52,7 @@ enum class KernelVariant {
// d=64 GQA-8 (num_queries_per_kv = 8)
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
decode_d64_gqa8_m128_p32, // kBlockM=128, 4 warps, 32x32 mfma, kBlockN=32
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma
decode_d64_gqa8_m64_p32, // kBlockM=64, 2 warps, 32x32 mfma, kBlockN=32
decode_d64_gqa8_m32_p32, // kBlockM=32, 2 warps, 16x16 mfma, kBlockN=32
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma
};
@@ -114,28 +109,14 @@ KernelConfig select_config(const unified_attention_args& args)
return cfg;
}
// d=64 GQA-8 — full tile-tier ladder, with _p32 variants for the legacy
// kBlockN=32 path used when page_blk_size < 64.
// d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
if (args.hdim == 64 && args.num_queries_per_kv == 8)
{
const bool p32 = (args.page_blk_size < 64);
switch (select_tile_tier(args))
{
case tile_tier::tiny:
// p32: 2-warp 16x16 (kBlockM=32) -- avoids the 1-warp+p32 race.
// p64: 1-warp 16x16 (kBlockM=16) -- matches Triton BLOCK_M=16.
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m32_p32
: KernelVariant::decode_d64_gqa8_m16;
break;
case tile_tier::small:
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m64_p32
: KernelVariant::decode_d64_gqa8_m64;
break;
case tile_tier::medium:
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m128_p32
: KernelVariant::decode_d64_gqa8_m128;
break;
case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break;
}
return cfg;
}
@@ -150,7 +131,7 @@ KernelConfig select_config(const unified_attention_args& args)
// Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and
// returns. The per-variant helpers below pick the right macro family and fan
// out over (dtype, mask). They look repetitive on purpose: a follow-up commit
// will collapse the 5 traits classes into one templated `kernel_traits<V>`,
// will collapse the 4 traits classes into one templated `kernel_traits<V>`,
// at which point these helpers become one-liners.
// -----------------------------------------------------------------------------
@@ -180,25 +161,6 @@ KernelConfig select_config(const unified_attention_args& args)
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// block_size=32 dispatch macros (6th template arg = 32).
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
namespace {
using DType = unified_attention_args::data_type_enum;
@@ -259,20 +221,6 @@ std::pair<bool, float> dispatch_decode_d64_gqa8_m128(
return {false, -1.f};
}
std::pair<bool, float> dispatch_decode_d64_gqa8_m128_p32(
const unified_attention_args& args, const stream_config& config)
{
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
if (args.data_type == DType::fp16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, true, 64, 128, 8)
} else if (args.data_type == DType::bf16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, true, 64, 128, 8)
}
return {false, -1.f};
}
std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
const unified_attention_args& args, const stream_config& config)
{
@@ -287,34 +235,6 @@ std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
return {false, -1.f};
}
std::pair<bool, float> dispatch_decode_d64_gqa8_m64_p32(
const unified_attention_args& args, const stream_config& config)
{
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
if (args.data_type == DType::fp16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, true, 64, 64, 8)
} else if (args.data_type == DType::bf16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, true, 64, 64, 8)
}
return {false, -1.f};
}
std::pair<bool, float> dispatch_decode_d64_gqa8_m32_p32(
const unified_attention_args& args, const stream_config& config)
{
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
if (args.data_type == DType::fp16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, true, 64, 32, 8)
} else if (args.data_type == DType::bf16) {
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, true, 64, 32, 8)
}
return {false, -1.f};
}
std::pair<bool, float> dispatch_decode_d64_gqa8_m16(
const unified_attention_args& args, const stream_config& config)
{
@@ -347,22 +267,16 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
switch (cfg.variant)
{
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config);
case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
case KernelVariant::decode_d64_gqa8_m64_p32: return dispatch_decode_d64_gqa8_m64_p32(args, config);
case KernelVariant::decode_d64_gqa8_m32_p32: return dispatch_decode_d64_gqa8_m32_p32(args, config);
case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
}
return std::make_pair(false, -1.f);
}
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM

View File

@@ -316,68 +316,6 @@ struct unified_attention_decode_tiny_kernel_traits
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
};
// bs32 decode traits: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4 for GQA-8.
// Used for block_size=32 decode: avoids the 1-warp pipeline race condition
// and reduces query waste from 87.5% (small tier kBlockQ=8) to 75% (kBlockQ=4).
template <unified_attention_args::data_type_enum DataType,
bool IsMasking,
index_t HeadSize_ = 64,
index_t BlockM_ = 32,
index_t NumQPerKV_ = 8,
index_t BlockSize_ = 32>
struct unified_attention_decode_bs32_kernel_traits
{
// NOTE(review): "date_type" looks like a pre-existing typo for "data_type";
// it is used consistently below, so it is left untouched here.
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t BLOCK_SIZE = BlockSize_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
// With the defaults (kBlockM=32, NumQPerKV=8) this yields kBlockQ = 4.
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
// 16x16x32 warp gemm, 2x1x1 warps; the same shape/warp layout is passed
// twice to the shape below -- presumably for the QK and PV gemms (confirm
// against TileUnifiedAttentionShape).
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
using unified_attention_block_warps = sequence<2, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
true>;
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
// Mask is compile-time gated on IsMasking only; no local-mask support here.
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
// Problem dtypes are all derived from the per-datatype problem-traits helper:
// q/k/v/p share qkvp_dtype, accumulators use acc_dtype, output uses o_dtype.
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
unified_attention_shape,
unified_attention_mask,
unified_attention_traits>;
// Decode-specific pipeline policy (vs. the default policy used by prefill).
using unified_attention_pipeline =
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
UnifiedAttentionPipelineDecodePolicy>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
true, true, true>>;
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
};
template <typename Kernel, bool UseDecodeGrid = false>
float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)