mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Delete CK-UA bs32 variant family
The bs32 variants existed because pre-fix the pipeline required
kBlockN <= page_size, so page_size=32 forced a kBlockN=32 kernel
family. The multi-page-tile fix (commit 473869aba) lifted that
constraint and made kBlockN compile-time-independent of the runtime
page size, so the bs32 family is now redundant: every non-bs32 variant
is correct for any page_size.
This was validated in advance by forcing use_bs32=false in the
dispatcher and running the full correctness suite -- 236/240, identical
to baseline (the 4 remaining failures are the pre-existing int32-
overflow case, orthogonal).
Removes:
* 16 instances/unified_attention_*_bs32_*.cpp files
* unified_attention_decode_bs32_kernel_traits in unified_attention_impl.hpp
* 3 _BS32 dispatch macros in unified_attention.cpp
* 3 _p32 entries from the KernelVariant enum
* 3 dispatch_*_p32 helper functions and their switch cases
* the page_blk_size branch in select_config (now a pure tile-tier ladder)
Net: 12 fewer compile units (build time -6s on JIT), 78 fewer dispatcher
lines, and "which kernel runs?" is now driven purely by Q-tile shape.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -26,10 +26,10 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
// 1. KernelVariant + select_config(args)
|
||||
// - KernelVariant is a flat enum of every compiled kernel instance the
|
||||
// module knows about. Each entry fixes the static knobs (kBlockM,
|
||||
// warp count, MFMA shape, pipeline policy, optional kBlockN override).
|
||||
// warp count, MFMA shape, pipeline policy).
|
||||
// - select_config() is the ONLY place where shape-based runtime
|
||||
// decisions live. It reads the problem (hdim, num_queries_per_kv,
|
||||
// page_blk_size, avg_q, max_seqlen_q) and emits a KernelConfig.
|
||||
// decisions live. It reads (hdim, num_queries_per_kv, avg_q,
|
||||
// max_seqlen_q) and emits a KernelConfig.
|
||||
//
|
||||
// 2. dispatch_<variant>() helpers + the final switch
|
||||
// - Each KernelVariant has a tiny helper that fans out over the
|
||||
@@ -38,12 +38,10 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
// per-variant traits classes are unchanged from before; only the
|
||||
// selection logic moved.
|
||||
//
|
||||
// Phase-1 note: page-size is currently still a static axis in the enum
|
||||
// (the _p32 suffix marks the variant with kBlockN=32 that was originally
|
||||
// required when page_size < 64). The multi-page-tile fix in the pipeline
|
||||
// removed the underlying constraint, so a follow-up commit deletes the
|
||||
// _p32 (a.k.a. "bs32") family entirely. Doing it in two steps keeps each
|
||||
// diff easy to bisect against the test suite.
|
||||
// page_size is intentionally NOT part of this enum. The multi-page-tile
|
||||
// fix in the pipeline made the compile-time tile-N (kBlockN) independent
|
||||
// of the runtime page_blk_size, so every variant is correct for any page
|
||||
// size. Selection is driven purely by Q-tile shape.
|
||||
// =============================================================================
|
||||
|
||||
enum class KernelVariant {
|
||||
@@ -54,10 +52,7 @@ enum class KernelVariant {
|
||||
// d=64 GQA-8 (num_queries_per_kv = 8)
|
||||
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m128_p32, // kBlockM=128, 4 warps, 32x32 mfma, kBlockN=32
|
||||
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m64_p32, // kBlockM=64, 2 warps, 32x32 mfma, kBlockN=32
|
||||
decode_d64_gqa8_m32_p32, // kBlockM=32, 2 warps, 16x16 mfma, kBlockN=32
|
||||
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma
|
||||
};
|
||||
|
||||
@@ -114,28 +109,14 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
return cfg;
|
||||
}
|
||||
|
||||
// d=64 GQA-8 — full tile-tier ladder, with _p32 variants for the legacy
|
||||
// kBlockN=32 path used when page_blk_size < 64.
|
||||
// d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
|
||||
if (args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
const bool p32 = (args.page_blk_size < 64);
|
||||
|
||||
switch (select_tile_tier(args))
|
||||
{
|
||||
case tile_tier::tiny:
|
||||
// p32: 2-warp 16x16 (kBlockM=32) -- avoids the 1-warp+p32 race.
|
||||
// p64: 1-warp 16x16 (kBlockM=16) -- matches Triton BLOCK_M=16.
|
||||
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m32_p32
|
||||
: KernelVariant::decode_d64_gqa8_m16;
|
||||
break;
|
||||
case tile_tier::small:
|
||||
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m64_p32
|
||||
: KernelVariant::decode_d64_gqa8_m64;
|
||||
break;
|
||||
case tile_tier::medium:
|
||||
cfg.variant = p32 ? KernelVariant::decode_d64_gqa8_m128_p32
|
||||
: KernelVariant::decode_d64_gqa8_m128;
|
||||
break;
|
||||
case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
|
||||
case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
|
||||
case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break;
|
||||
}
|
||||
return cfg;
|
||||
}
|
||||
@@ -150,7 +131,7 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
// Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and
|
||||
// returns. The per-variant helpers below pick the right macro family and fan
|
||||
// out over (dtype, mask). They look repetitive on purpose: a follow-up commit
|
||||
// will collapse the 5 traits classes into one templated `kernel_traits<V>`,
|
||||
// will collapse the 4 traits classes into one templated `kernel_traits<V>`,
|
||||
// at which point these helpers become one-liners.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
@@ -180,25 +161,6 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// block_size=32 dispatch macros (6th template arg = 32).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using DType = unified_attention_args::data_type_enum;
|
||||
@@ -259,20 +221,6 @@ std::pair<bool, float> dispatch_decode_d64_gqa8_m128(
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m128_p32(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::fp16, true, 64, 128, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType::bf16, true, 64, 128, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
@@ -287,34 +235,6 @@ std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m64_p32(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::fp16, true, 64, 64, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType::bf16, true, 64, 64, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m32_p32(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::fp16, true, 64, 32, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType::bf16, true, 64, 32, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m16(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
@@ -347,22 +267,16 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
|
||||
switch (cfg.variant)
|
||||
{
|
||||
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64_p32: return dispatch_decode_d64_gqa8_m64_p32(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m32_p32: return dispatch_decode_d64_gqa8_m32_p32(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
|
||||
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
|
||||
}
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM
|
||||
|
||||
@@ -316,68 +316,6 @@ struct unified_attention_decode_tiny_kernel_traits
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// bs32 decode traits: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4 for GQA-8.
|
||||
// Used for block_size=32 decode: avoids the 1-warp pipeline race condition
|
||||
// and reduces query waste from 87.5% (small tier kBlockQ=8) to 75% (kBlockQ=4).
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 32,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = 32>
|
||||
struct unified_attention_decode_bs32_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel, bool UseDecodeGrid = false>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
|
||||
Reference in New Issue
Block a user