Add num_splits option and dummy split-kv api method

This commit is contained in:
PoYen, Chen
2024-05-31 08:59:42 +00:00
parent abc7e7ed30
commit c928fefaae
2 changed files with 193 additions and 1 deletions

View File

@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -155,6 +158,20 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
}
}
float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
if(1 < args.num_splits)
{
return fmha_fwd_splitkv(traits, args, config);
}
else
{
return fmha_fwd(traits, args, config);
}
};
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -260,6 +277,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset();
}
int num_splits = arg_parser.get_int("num_splits");
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
@@ -361,6 +380,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 5>{num_splits,
shape_batch,
nhead,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
@@ -443,6 +475,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -553,6 +587,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(),
lse_acc_buf.GetDeviceBuffer(),
o_acc_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
@@ -566,6 +602,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
hdim_v,
nhead,
nhead_k,
num_splits,
scale_s,
scale_p,
scale_o,
@@ -598,7 +635,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
if(ave_time < 0)
{

View File

@@ -10,6 +10,7 @@
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
#include <iostream>
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -93,6 +94,8 @@ struct fmha_fwd_args
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
@@ -106,6 +109,7 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s;
float scale_p;
float scale_o;
@@ -234,6 +238,149 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
#if 0
template <typename FmhaFwdSplitKVKernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaFwdSplitKVKernel::kIsGroupMode)
{
return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.nhead,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
else
{ // create batch mode kernel arguments
return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.nhead,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
}();
dim3 grids = FmhaFwdSplitKVKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaFwdSplitKVCombineKernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel argumentszs
if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode)
{
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o);
}
else
{ // create batch mode kernel arguments
return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.nhead,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse,
args.batch_stride_o);
}
}();
dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
#endif
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
@@ -282,6 +429,9 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
float fmha_splitkv_(const ck_tile::stream_config&, fmha_fwd_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
@@ -298,3 +448,8 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
inline float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&)
{
std::cout << __PRETTY_FUNCTION__ << std::endl;
return 0;
}