Fix more format

This commit is contained in:
PoYen, Chen
2024-08-16 10:32:17 +00:00
parent 5728c0be65
commit 2523c8e36c
4 changed files with 132 additions and 126 deletions

View File

@@ -256,8 +256,8 @@ int override_num_splits_if_necessary(
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
@@ -307,9 +307,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) {
std::cerr << "kvcache enabled. ignoring the 'mode' option"
<< std::endl;
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch)
{
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
mode = mode_enum::batch;
}
@@ -780,7 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
@@ -871,12 +871,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = shape_seqlen_q;
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.batch = batch;
args.seqlen_q = shape_seqlen_q;
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
@@ -919,9 +919,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
args.max_seqlen_q = max_seqlen_q;

View File

@@ -164,9 +164,8 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or
// kvcache is used
const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not
// nullptr, or kvcache is used
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr
@@ -521,38 +520,38 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);

View File

@@ -128,78 +128,78 @@ struct FmhaFwdAppendKVKernel
ck_tile::index_t rotary_dim;
};
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
{};
__host__ static constexpr Kargs
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
struct Kargs : BasicKargs, std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
{
Kargs kargs{{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
-1, // seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{} // placeholder for rope
};
};
__host__ static constexpr Kargs MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
{
Kargs kargs{
{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
-1, // seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{} // placeholder for rope
};
if constexpr(kApplyRoPE)
{
@@ -229,11 +229,14 @@ struct FmhaFwdAppendKVKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
const long_index_t batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_q =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_knew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
const long_index_t batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_v =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_vnew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;

View File

@@ -223,7 +223,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
const void* seqlen_k_ptr, // only used for (paged-) kvcache
const void* seqlen_k_ptr, // only used for (paged-) kvcache
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -495,7 +495,7 @@ struct FmhaFwdSplitKVKernel
}
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier