Fixes compilation issues for unified attention

This commit is contained in:
Amir Ghamarian
2026-03-27 08:03:06 -05:00
parent eb3011c525
commit 5cd4b441ab
6 changed files with 126 additions and 53 deletions

View File

@@ -26,7 +26,7 @@
#include "mask.hpp"
// const ck_tile::index_t page_blk_size = 32;
const ck_tile::index_t num_queries_per_kv = 1;
// num_queries_per_kv is now a runtime arg (see parse_cmd_args)
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
@@ -34,10 +34,8 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
arg_parser
.insert("prec", "bf16", "data type. fp16/bf16")
// .insert("b", "3", "batch size")
.insert("h_k",
"8",
"num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) +
" times this")
.insert("nqpkv", "1", "num queries per kv head (GQA ratio, e.g. 1 for MHA, 8 for GQA-8)")
.insert("h_k", "8", "num head for k/v. num head for q is nqpkv times this")
.insert("s", "3328", "max seqlen_q")
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
.insert("nb", "1024", "num_blks")
@@ -130,7 +128,7 @@ struct Problem
: ck_tile::unified_attention_args::data_type_enum::bf16;
num_blks = args.get_int("nb");
nhead_kv = args.get_int("h_k");
// TODO: support other GQA/MQA cases than just 4x
num_queries_per_kv = args.get_int("nqpkv");
nhead_q = nhead_kv * num_queries_per_kv;
ck_tile::index_t max_seqlen_q = args.get_int("s");
@@ -192,6 +190,7 @@ struct Problem
ck_tile::index_t num_blks;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t num_queries_per_kv;
ck_tile::index_t hdim;
ck_tile::index_t page_blk_size;
ck_tile::index_t num_tokens;
@@ -335,7 +334,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.data_type = problem.data_type;
args.num_seqs = problem.batch;
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = num_queries_per_kv;
args.num_queries_per_kv = problem.num_queries_per_kv;
args.page_blk_size = problem.page_blk_size;
args.mask_type = 2;
args.hdim = problem.hdim;

View File

@@ -18,47 +18,43 @@ std::ostream& operator<<(std::ostream& stream,
}
}
// Helper macro to reduce dispatch boilerplate.
// Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config)
{
if(args.data_type == unified_attention_args::data_type_enum::fp16)
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
// Route based on (data_type, mask, hdim, num_queries_per_kv).
// Only d128 MHA (8 warps, kBlockM=256) instances available.
// Decode-tuned instances require pipeline changes (NumWarpGroups must == 2,
// which means exactly 8 warps; fewer warps are not supported).
if(args.hdim == 128 && args.num_queries_per_kv == 1)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16,
false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 128, 256, 1)
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 128, 256, 1)
}
else
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16,
false>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 128, 256, 1)
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 128, 256, 1)
}
}
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
<< " num_queries_per_kv=" << args.num_queries_per_kv
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl;
return std::make_pair(false, -1.f);
}
#undef DISPATCH_UNIFIED_ATTENTION
} // namespace ck_tile

View File

@@ -52,24 +52,29 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
using lse_dtype = float;
};
template <unified_attention_args::data_type_enum DataType, bool IsMasking>
// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV
template <unified_attention_args::data_type_enum DataType,
bool IsMasking,
index_t HeadSize_ = 128,
index_t BlockM_ = 256,
index_t NumQPerKV_ = 1>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr index_t kBlockM = 256;
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t BLOCK_SIZE = 32;
static constexpr index_t HEAD_SIZE = 128;
static constexpr index_t HEAD_SIZE = HeadSize_;
// TODO please fix this to support also other num_queries_per_kv
static constexpr index_t num_queries_per_kv = 1;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
// need to have 8 warps per workgroup to have warp specialization
// 8 warps for warp specialization; kBlockM must be 8 * 32 = 256
using unified_attention_block_warps = sequence<8, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
@@ -115,17 +120,70 @@ struct unified_attention_kernel_traits
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
};
// Decode-tuned traits: fewer warps, smaller kBlockM for low-token workloads.
// NOTE: Currently cannot compile due to pipeline constraint (NumWarpGroups must == 2).
// Kept for future pipeline work.
template <unified_attention_args::data_type_enum DataType,
bool IsMasking,
index_t HeadSize_ = 128,
index_t BlockM_ = 64,
index_t NumQPerKV_ = 1>
struct unified_attention_decode_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t BLOCK_SIZE = 32;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
using unified_attention_block_warps = sequence<2, 1, 1>;
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
unified_attention_block_warps,
unified_attention_warp_gemm_shape,
true>;
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::lse_dtype,
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
unified_attention_shape,
unified_attention_mask,
unified_attention_traits>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
typename unified_attention_problem_traits<date_type>::o_dtype,
true, true, true>>;
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
};
template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)
{
index_t kBlockQ = Kernel::kBlockQ;
assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
assert(kBlockQ == kBlockM / args.num_queries_per_kv &&
"kBlockQ must equal kBlockM / num_queries_per_kv");
constexpr index_t kBlockQ = Kernel::kBlockQ;
index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,

View File

@@ -20,7 +20,14 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
{
for(int m = 0; m < M; ++m)
{
if(mask.IsOutOfSinkBound(m, n))
const bool is_out_of_bound = [&]() {
if constexpr(requires { mask.IsOutOfSinkBound(m, n); })
return mask.IsOutOfSinkBound(m, n);
else
return mask.IsOutOfBound(m, n);
}();
if(is_out_of_bound)
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
}
}

View File

@@ -257,6 +257,19 @@ struct GenericAttentionMask
index_t repeat_idx;
};
template <typename>
struct is_generic_attention_mask : std::false_type
{
};
template <bool IsMasking, bool IsLocal>
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
{
};
template <typename Mask>
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask

View File

@@ -40,9 +40,9 @@ struct TileUnifiedAttentionShape
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);