mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
fix
This commit is contained in:
@@ -177,38 +177,9 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv &&
|
||||
"BLOCK_Q must equal BLOCK_M / num_queries_per_kv");
|
||||
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
|
||||
auto kargs =
|
||||
auto kargs_lambda =
|
||||
[&]() {
|
||||
if(Kernel::QuantEnum == UnifiedAttentionQuantScaleEnum::NO_SCALE)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.o_ptr,
|
||||
args.num_blks,
|
||||
args.num_head_q,
|
||||
args.num_queries_per_kv,
|
||||
args.scale_s,
|
||||
total_num_q_blocks,
|
||||
args.query_stride_0,
|
||||
args.query_stride_1,
|
||||
args.stride_k_cache_0,
|
||||
args.stride_k_cache_1,
|
||||
args.stride_k_cache_2,
|
||||
args.stride_k_cache_3,
|
||||
args.stride_v_cache_0,
|
||||
args.stride_v_cache_1,
|
||||
args.stride_v_cache_2,
|
||||
args.stride_v_cache_3,
|
||||
args.output_stride_0,
|
||||
args.output_stride_1,
|
||||
args.block_tables_ptr,
|
||||
args.block_table_stride,
|
||||
args.seq_lens_ptr,
|
||||
args.query_start_len_ptr,
|
||||
args.num_seqs);
|
||||
}
|
||||
else
|
||||
if constexpr(Kernel::kIsQuantized)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
@@ -241,8 +212,40 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
args.query_start_len_ptr,
|
||||
args.num_seqs);
|
||||
}
|
||||
} dim3 grids =
|
||||
Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
else
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.o_ptr,
|
||||
args.num_blks,
|
||||
args.num_head_q,
|
||||
args.num_queries_per_kv,
|
||||
args.scale_s,
|
||||
total_num_q_blocks,
|
||||
args.query_stride_0,
|
||||
args.query_stride_1,
|
||||
args.stride_k_cache_0,
|
||||
args.stride_k_cache_1,
|
||||
args.stride_k_cache_2,
|
||||
args.stride_k_cache_3,
|
||||
args.stride_v_cache_0,
|
||||
args.stride_v_cache_1,
|
||||
args.stride_v_cache_2,
|
||||
args.stride_v_cache_3,
|
||||
args.output_stride_0,
|
||||
args.output_stride_1,
|
||||
args.block_tables_ptr,
|
||||
args.block_table_stride,
|
||||
args.seq_lens_ptr,
|
||||
args.query_start_len_ptr,
|
||||
args.num_seqs);
|
||||
}
|
||||
};
|
||||
|
||||
auto kargs = kargs_lambda();
|
||||
|
||||
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ struct UnifiedAttentionKernel
|
||||
static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename UnifiedAttentionPipeline::VDataType>;
|
||||
@@ -38,6 +39,7 @@ struct UnifiedAttentionKernel
|
||||
static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV;
|
||||
static constexpr auto QuantEnum = UnifiedAttentionPipeline::Problem::QuantEnum;
|
||||
|
||||
|
||||
// TODO add yjese
|
||||
static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE;
|
||||
@@ -50,13 +52,21 @@ struct UnifiedAttentionKernel
|
||||
static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q;
|
||||
// BLOCK size for K seqlen
|
||||
static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
static constexpr bool kIsQuantized = (QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE);
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
|
||||
// arg
|
||||
struct UnifiedAttentionEmptyKargs
|
||||
{
|
||||
};
|
||||
|
||||
struct UnifiedAttentionQuantKargs
|
||||
{
|
||||
float scale_q;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
float scale_out;
|
||||
};
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
@@ -92,19 +102,10 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t output_stride_1;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionQuantKargs
|
||||
{
|
||||
float scale_q;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
float scale_out;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionVarlenKargs
|
||||
: UnifiedAttentionCommonKargs,
|
||||
std::conditional_t<QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE,
|
||||
UnifiedAttentionQuantKargs,
|
||||
UnifiedAttentionEmptyKargs<0>>
|
||||
std::conditional_t<kIsQuantized, UnifiedAttentionQuantKargs, UnifiedAttentionEmptyKargs<0>>
|
||||
{
|
||||
const int32_t* block_tables_ptr;
|
||||
ck_tile::index_t block_table_stride;
|
||||
@@ -116,7 +117,7 @@ struct UnifiedAttentionKernel
|
||||
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
|
||||
template <bool Cond = QuantEnum == UnifiedAttentionQuantScaleEnum::NO_SCALE>
|
||||
template <bool Cond = !kIsQuantized>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
@@ -153,7 +154,6 @@ struct UnifiedAttentionKernel
|
||||
num_head_q,
|
||||
num_queries_per_kv,
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
scale_s,
|
||||
total_num_q_blocks,
|
||||
query_stride_0,
|
||||
query_stride_1,
|
||||
@@ -167,6 +167,7 @@ struct UnifiedAttentionKernel
|
||||
stride_v_cache_3,
|
||||
output_stride_0,
|
||||
output_stride_1},
|
||||
{},
|
||||
block_tables_ptr,
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
@@ -176,7 +177,7 @@ struct UnifiedAttentionKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE>
|
||||
template <bool Cond = kIsQuantized>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
@@ -488,7 +489,7 @@ struct UnifiedAttentionKernel
|
||||
// both quantized
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if(QuantEnum != UnifiedAttentionQuantScaleEnum::NO_SCALE)
|
||||
if constexpr(kIsQuantized)
|
||||
{
|
||||
if(std::is_same_v<QDataType, KDataType>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user