From c928fefaae41320b129e1b2f5e4f5b35ec43f8b6 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 31 May 2024 08:59:42 +0000 Subject: [PATCH] Add num_splits option and dummy split-kv api method --- example/ck_tile/01_fmha/fmha_fwd.cpp | 39 ++++++- example/ck_tile/01_fmha/fmha_fwd.hpp | 155 +++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5f887f0655..94e147ad63 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[]) .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -155,6 +158,20 @@ auto get_elimit(std::string init_method) } } +float fmha_fwd_dispatch(fmha_fwd_traits traits, + fmha_fwd_args args, + const ck_tile::stream_config& config) +{ + if(1 < args.num_splits) + { + return fmha_fwd_splitkv(traits, args, config); + } + else + { + return fmha_fwd(traits, args, config); + } +}; + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -260,6 +277,8 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } + int num_splits = arg_parser.get_int("num_splits"); + int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); @@ -361,6 +380,19 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{batch, nhead}) : std::array{1, 1}); + ck_tile::HostTensor lse_acc_host( + 1 < num_splits + ? 
std::array{num_splits, shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1, 1}); + + ck_tile::HostTensor o_acc_host( + 1 < num_splits ? std::array{num_splits, + shape_batch, + nhead, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( lse ? std::array{batch, nhead, max_seqlen_q} @@ -443,6 +475,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); @@ -553,6 +587,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::alibi ? 
alibi_slope_buf.GetDeviceBuffer() : bias_buf.GetDeviceBuffer(), randval_buf.GetDeviceBuffer(), + lse_acc_buf.GetDeviceBuffer(), + o_acc_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -566,6 +602,7 @@ bool run(const ck_tile::ArgParser& arg_parser) hdim_v, nhead, nhead_k, + num_splits, scale_s, scale_p, scale_o, @@ -598,7 +635,7 @@ bool run(const ck_tile::ArgParser& arg_parser) {drop_seed, drop_offset}}; }(); - float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); if(ave_time < 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3594f61db9..ee4506591e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -10,6 +10,7 @@ #include "mask.hpp" #include "bias.hpp" #include +#include template struct FmhaFwdTypeConfig; @@ -93,6 +94,8 @@ struct fmha_fwd_args const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -106,6 +109,7 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; float scale_s; float scale_p; float scale_o; @@ -234,6 +238,149 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } +#if 0 +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaFwdSplitKVKernel::kIsGroupMode) + { + return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + 
args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = FmhaFwdSplitKVKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode) + { + return FmhaFwdSplitKVCombineKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + 
args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o); + } + else + { // create batch mode kernel arguments + return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse, + args.batch_stride_o); + } + }(); + + dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} +#endif + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +float fmha_splitkv_(const ck_tile::stream_config&, fmha_fwd_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -298,3 +448,8 @@ struct fmha_fwd_traits // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +inline float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&) +{ + std::cout << __PRETTY_FUNCTION__ << std::endl; + return 0; +}