mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Fixes compilation issues for unified attention
This commit is contained in:
@@ -26,7 +26,7 @@
|
||||
#include "mask.hpp"
|
||||
|
||||
// const ck_tile::index_t page_blk_size = 32;
|
||||
const ck_tile::index_t num_queries_per_kv = 1;
|
||||
// num_queries_per_kv is now a runtime arg (see parse_cmd_args)
|
||||
|
||||
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
|
||||
{
|
||||
@@ -34,10 +34,8 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
arg_parser
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
// .insert("b", "3", "batch size")
|
||||
.insert("h_k",
|
||||
"8",
|
||||
"num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) +
|
||||
" times this")
|
||||
.insert("nqpkv", "1", "num queries per kv head (GQA ratio, e.g. 1 for MHA, 8 for GQA-8)")
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is nqpkv times this")
|
||||
.insert("s", "3328", "max seqlen_q")
|
||||
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
|
||||
.insert("nb", "1024", "num_blks")
|
||||
@@ -130,7 +128,7 @@ struct Problem
|
||||
: ck_tile::unified_attention_args::data_type_enum::bf16;
|
||||
num_blks = args.get_int("nb");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
// TODO: support other GQA/MQA cases than just 4x
|
||||
num_queries_per_kv = args.get_int("nqpkv");
|
||||
nhead_q = nhead_kv * num_queries_per_kv;
|
||||
|
||||
ck_tile::index_t max_seqlen_q = args.get_int("s");
|
||||
@@ -192,6 +190,7 @@ struct Problem
|
||||
ck_tile::index_t num_blks;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_kv;
|
||||
ck_tile::index_t num_queries_per_kv;
|
||||
ck_tile::index_t hdim;
|
||||
ck_tile::index_t page_blk_size;
|
||||
ck_tile::index_t num_tokens;
|
||||
@@ -335,7 +334,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = num_queries_per_kv;
|
||||
args.num_queries_per_kv = problem.num_queries_per_kv;
|
||||
args.page_blk_size = problem.page_blk_size;
|
||||
args.mask_type = 2;
|
||||
args.hdim = problem.hdim;
|
||||
|
||||
@@ -18,47 +18,43 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper macro to reduce dispatch boilerplate.
|
||||
// Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
|
||||
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
|
||||
// Route based on (data_type, mask, hdim, num_queries_per_kv).
|
||||
// Only d128 MHA (8 warps, kBlockM=256) instances available.
|
||||
// Decode-tuned instances require pipeline changes (NumWarpGroups must == 2,
|
||||
// which means exactly 8 warps; fewer warps are not supported).
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16,
|
||||
false>;
|
||||
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 128, 256, 1)
|
||||
}
|
||||
else
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
|
||||
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
{
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16,
|
||||
false>;
|
||||
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
|
||||
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config);
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 128, 256, 1)
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
|
||||
<< " num_queries_per_kv=" << args.num_queries_per_kv
|
||||
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl;
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -52,24 +52,29 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <unified_attention_args::data_type_enum DataType, bool IsMasking>
|
||||
// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 256,
|
||||
index_t NumQPerKV_ = 1>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = 256;
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t BLOCK_SIZE = 32;
|
||||
static constexpr index_t HEAD_SIZE = 128;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
|
||||
// TODO please fix this to support also other num_queries_per_kv
|
||||
static constexpr index_t num_queries_per_kv = 1;
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// need to have 8 warps per workgroup to have warp specialization
|
||||
// 8 warps for warp specialization; kBlockM must be 8 * 32 = 256
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
@@ -115,17 +120,70 @@ struct unified_attention_kernel_traits
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Decode-tuned traits: fewer warps, smaller kBlockM for low-token workloads.
|
||||
// NOTE: Currently cannot compile due to pipeline constraint (NumWarpGroups must == 2).
|
||||
// Kept for future pipeline work.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 64,
|
||||
index_t NumQPerKV_ = 1>
|
||||
struct unified_attention_decode_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t BLOCK_SIZE = 32;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
index_t kBlockQ = Kernel::kBlockQ;
|
||||
assert(args.num_queries_per_kv == Kernel::num_queries_per_kv &&
|
||||
"argument num_queries_per_kv must equal compiled num_queries_per_kv");
|
||||
assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE &&
|
||||
"argument BLOCK_SIZE must equal compiled BLOCK_SIZE");
|
||||
assert(kBlockQ == kBlockM / args.num_queries_per_kv &&
|
||||
"kBlockQ must equal kBlockM / num_queries_per_kv");
|
||||
constexpr index_t kBlockQ = Kernel::kBlockQ;
|
||||
index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
|
||||
@@ -20,7 +20,14 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
if(mask.IsOutOfSinkBound(m, n))
|
||||
const bool is_out_of_bound = [&]() {
|
||||
if constexpr(requires { mask.IsOutOfSinkBound(m, n); })
|
||||
return mask.IsOutOfSinkBound(m, n);
|
||||
else
|
||||
return mask.IsOutOfBound(m, n);
|
||||
}();
|
||||
|
||||
if(is_out_of_bound)
|
||||
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,6 +257,19 @@ struct GenericAttentionMask
|
||||
index_t repeat_idx;
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct is_generic_attention_mask : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <bool IsMasking, bool IsLocal>
|
||||
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Mask>
|
||||
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
|
||||
@@ -40,9 +40,9 @@ struct TileUnifiedAttentionShape
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
Reference in New Issue
Block a user