From 3e2b69e16321dc1998d08480f736d25d3d57ba8c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Aug 2024 17:26:09 +0000 Subject: [PATCH] Display more info for specific kernels --- example/ck_tile/01_fmha/fmha_fwd.cpp | 29 ++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 631eff10db..b3ca41282c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -119,12 +119,12 @@ auto create_args(int argc, char* argv[]) .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("num_splits", - "1", - "# of splits for key/value. 0 to determine actual number by heuristic") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe.") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -379,7 +379,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string init_method = arg_parser.get_str("init"); - const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); if constexpr(!(std::is_same_v || std::is_same_v)) { @@ -389,6 +389,14 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } } +#if !CK_TILE_FMHA_FWD_APPENDKV_API + else if(0 < rotary_dim) + { + std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" + << std::endl; + rotary_dim = 0; + } +#endif if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; @@ -730,10 +738,23 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout; +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < rotary_dim) + { + std::cout << ", rotary_dim:" << rotary_dim << "(" + << (is_rotary_interleaved ? "inter" : "half") << ")"; + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; } + if(0 < page_block_size) + { + std::cout << ", page_block_size:" << page_block_size; + } +#endif std::cout << std::flush; float appendkv_ave_time = 0;