This commit is contained in:
Juuso Korhonen
2025-12-05 10:52:15 +00:00
parent f01e964e22
commit 0501d37efe
2 changed files with 53 additions and 49 deletions

View File

@@ -177,38 +177,9 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv &&
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
auto kargs =
auto kargs_lambda =
[&]() {
if(Kernel::QuantEnum == UnifiedAttentionQuantScaleEnum::NO_SCALE)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.o_ptr,
args.num_blks,
args.num_head_q,
args.num_queries_per_kv,
args.scale_s,
total_num_q_blocks,
args.query_stride_0,
args.query_stride_1,
args.stride_k_cache_0,
args.stride_k_cache_1,
args.stride_k_cache_2,
args.stride_k_cache_3,
args.stride_v_cache_0,
args.stride_v_cache_1,
args.stride_v_cache_2,
args.stride_v_cache_3,
args.output_stride_0,
args.output_stride_1,
args.block_tables_ptr,
args.block_table_stride,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs);
}
else
if constexpr(Kernel::kIsQuantized)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
@@ -241,8 +212,40 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
args.query_start_len_ptr,
args.num_seqs);
}
} dim3 grids =
Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
else
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.o_ptr,
args.num_blks,
args.num_head_q,
args.num_queries_per_kv,
args.scale_s,
total_num_q_blocks,
args.query_stride_0,
args.query_stride_1,
args.stride_k_cache_0,
args.stride_k_cache_1,
args.stride_k_cache_2,
args.stride_k_cache_3,
args.stride_v_cache_0,
args.stride_v_cache_1,
args.stride_v_cache_2,
args.stride_v_cache_3,
args.output_stride_0,
args.output_stride_1,
args.block_tables_ptr,
args.block_table_stride,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs);
}
};
auto kargs = kargs_lambda();
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;

View File

@@ -25,6 +25,7 @@ struct UnifiedAttentionKernel
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
using QDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
@@ -38,6 +39,7 @@ struct UnifiedAttentionKernel
static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV;
static constexpr auto QuantEnum = UnifiedAttentionPipeline::Problem::QuantEnum;
// TODO add these
static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE;
@@ -50,13 +52,21 @@ struct UnifiedAttentionKernel
static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q;
// BLOCK size for K seqlen
static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
static constexpr bool kIsQuantized = (QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE);
// Empty stand-in base class: UnifiedAttentionVarlenKargs inherits from
// UnifiedAttentionEmptyKargs<0> instead of UnifiedAttentionQuantKargs when
// kIsQuantized is false, so no extra fields are added in that case.
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
// arg
struct UnifiedAttentionEmptyKargs
{
};
// Scale factors carried in the kernel arguments. UnifiedAttentionVarlenKargs
// inherits from this struct only when kIsQuantized is true.
// NOTE(review): presumably the quantization scales for Q, K, V and the output;
// confirm their exact meaning and usage against the pipeline.
struct UnifiedAttentionQuantKargs
{
float scale_q;
float scale_k;
float scale_v;
float scale_out;
};
// Kargs use aggregate initialization, so no constructor is provided.
// Inheritance is used to minimize the Kargs size.
// Users need to call the MakeKargs() function to create Kargs.
@@ -92,19 +102,10 @@ struct UnifiedAttentionKernel
ck_tile::index_t output_stride_1;
};
struct UnifiedAttentionQuantKargs
{
float scale_q;
float scale_k;
float scale_v;
float scale_out;
};
struct UnifiedAttentionVarlenKargs
: UnifiedAttentionCommonKargs,
std::conditional_t<QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE,
UnifiedAttentionQuantKargs,
UnifiedAttentionEmptyKargs<0>>
std::conditional_t<kIsQuantized, UnifiedAttentionQuantKargs, UnifiedAttentionEmptyKargs<0>>
{
const int32_t* block_tables_ptr;
ck_tile::index_t block_table_stride;
@@ -116,7 +117,7 @@ struct UnifiedAttentionKernel
using Kargs = UnifiedAttentionVarlenKargs;
template <bool Cond = QuantEnum == UnifiedAttentionQuantScaleEnum::NO_SCALE>
template <bool Cond = !kIsQuantized>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
@@ -153,7 +154,6 @@ struct UnifiedAttentionKernel
num_head_q,
num_queries_per_kv,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
scale_s,
total_num_q_blocks,
query_stride_0,
query_stride_1,
@@ -167,6 +167,7 @@ struct UnifiedAttentionKernel
stride_v_cache_3,
output_stride_0,
output_stride_1},
{},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
@@ -176,7 +177,7 @@ struct UnifiedAttentionKernel
return kargs;
}
template <bool Cond = QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE>
template <bool Cond = kIsQuantized>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
@@ -488,7 +489,7 @@ struct UnifiedAttentionKernel
// both quantized
auto o_acc_tile = [&]() {
if(QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE)
if constexpr(kIsQuantized)
{
if(std::is_same_v<QDataType, KDataType>)
{