mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Fix more format
This commit is contained in:
@@ -256,8 +256,8 @@ int override_num_splits_if_necessary(
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
@@ -307,9 +307,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) {
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option"
|
||||
<< std::endl;
|
||||
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
|
||||
@@ -780,7 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else // fmha_fwd_traits or fmha_splitkv_traits
|
||||
{
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
@@ -871,12 +871,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
@@ -919,9 +919,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
@@ -164,9 +164,8 @@ struct fmha_fwd_splitkv_args
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or
|
||||
// kvcache is used
|
||||
const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not
|
||||
// nullptr, or kvcache is used
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr
|
||||
@@ -521,38 +520,38 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
args.stride_v,
|
||||
args.stride_vnew,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_knew,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_vnew,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
args.stride_v,
|
||||
args.stride_vnew,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_knew,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_vnew,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
|
||||
|
||||
|
||||
@@ -128,78 +128,78 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t rotary_dim;
|
||||
};
|
||||
|
||||
struct Kargs : BasicKargs,
|
||||
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
|
||||
{};
|
||||
|
||||
__host__ static constexpr Kargs
|
||||
MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_vnew,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
struct Kargs : BasicKargs, std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
knew_ptr,
|
||||
v_ptr,
|
||||
vnew_ptr,
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
seqlen_q,
|
||||
-1, // seqlen_k will be updated by content of seqlen_k_ptr
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
stride_v,
|
||||
stride_vnew,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_knew,
|
||||
nhead_stride_v,
|
||||
nhead_stride_vnew,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_knew,
|
||||
batch_stride_v,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{} // placeholder for rope
|
||||
};
|
||||
};
|
||||
|
||||
__host__ static constexpr Kargs MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_vnew,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
{
|
||||
Kargs kargs{
|
||||
{q_ptr,
|
||||
k_ptr,
|
||||
knew_ptr,
|
||||
v_ptr,
|
||||
vnew_ptr,
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
seqlen_q,
|
||||
-1, // seqlen_k will be updated by content of seqlen_k_ptr
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
stride_v,
|
||||
stride_vnew,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_knew,
|
||||
nhead_stride_v,
|
||||
nhead_stride_vnew,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_knew,
|
||||
batch_stride_v,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{} // placeholder for rope
|
||||
};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
@@ -229,11 +229,14 @@ struct FmhaFwdAppendKVKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
|
||||
|
||||
const long_index_t batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
const long_index_t batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
const long_index_t batch_offset_q =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
const long_index_t batch_offset_k =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
const long_index_t batch_offset_knew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
|
||||
const long_index_t batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
const long_index_t batch_offset_v =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
const long_index_t batch_offset_vnew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
|
||||
|
||||
|
||||
@@ -223,7 +223,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
|
||||
const void* seqlen_k_ptr, // only used for (paged-) kvcache
|
||||
const void* seqlen_k_ptr, // only used for (paged-) kvcache
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -495,7 +495,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
|
||||
Reference in New Issue
Block a user