From 0fd7f8550488741151d35b71981d4789ddfb75e5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Jun 2024 14:20:03 +0000 Subject: [PATCH] Use shorter template parameter name --- example/ck_tile/01_fmha/fmha_fwd.hpp | 227 +++++++++++++-------------- 1 file changed, 113 insertions(+), 114 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 52036c267d..6a4775f392 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -238,148 +238,147 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } -template <typename FmhaFwdSplitKVKernel> +template <typename Kernel> auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { // create group mode kernel arguments - if constexpr(FmhaFwdSplitKVKernel::kIsGroupMode) + if constexpr(Kernel::kIsGroupMode) { - return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.nhead, - args.max_seqlen_q, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / 
args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaFwdSplitKVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.nhead, - args.max_seqlen_q, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + 
args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); - dim3 grids = FmhaFwdSplitKVKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); return ck_tile::make_tuple(kargs, grids); } -template <typename FmhaFwdSplitKVCombineKernel> +template <typename Kernel> auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { // create group mode kernel argumentszs - if constexpr(FmhaFwdSplitKVCombineKernel::kIsGroupMode) + if constexpr(Kernel::kIsGroupMode) { - return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.nhead, - args.max_seqlen_q, - args.seqstart_q_ptr, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o, - args.nhead_stride_lse, - args.nhead_stride_o); + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o, + args.nhead_stride_lse, + args.nhead_stride_o); } else { // create batch mode kernel arguments - return FmhaFwdSplitKVCombineKernel::MakeKargs(args.lse_acc_ptr, - args.o_acc_ptr, - args.lse_ptr, - args.o_ptr, - args.batch, - args.nhead, - args.max_seqlen_q, - args.seqlen_q, - args.hdim_v, - args.num_splits, - args.scale_o, - args.stride_o, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_lse, - args.batch_stride_o); + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.nhead, + args.max_seqlen_q, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse, + 
args.batch_stride_o); } }(); - dim3 grids = FmhaFwdSplitKVCombineKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); return ck_tile::make_tuple(kargs, grids); }