diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 0196caceab..4b5635455c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -278,8 +278,13 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API +#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API) if(seqlen_knew != 0) { std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; @@ -291,6 +296,35 @@ bool run(const ck_tile::ArgParser& arg_parser) seqlen_knew = randint(1, arg_parser.get_int("s"), seed); } + ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + if(0 < rotary_dim) + { + std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; + return false; + } + } +#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API) + else if(0 < rotary_dim) + { + std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" + << std::endl; + rotary_dim = 0; + } +#endif + if(!(rotary_dim <= hdim_q)) + { + std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; + return false; + } + else if(!(rotary_dim % 16 == 0)) + { + std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; + return false; + } + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); #if !CK_TILE_FMHA_FWD_SPLITKV_API if(0 < page_block_size) @@ -323,9 +357,10 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::endl; use_cache_batch_idx = false; } - // the input layout we use for kvcache is same as batch mode - if((0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size) && - mode != mode_enum::batch) + // the input tensor layout for kvcache is same as batch mode + const bool use_kvcache = + (0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size); + if(use_kvcache && mode != mode_enum::batch) { std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; mode = mode_enum::batch; @@ -352,11 +387,6 @@ bool run(const ck_tile::ArgParser& arg_parser) // clang-format on #endif - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v < 0) - hdim_v = hdim_q; - bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim @@ -420,35 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string init_method = arg_parser.get_str("init"); - ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); - if constexpr(!(std::is_same_v || - std::is_same_v)) - { - if(0 < rotary_dim) - { - std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; - return false; - } - } -#if !CK_TILE_FMHA_FWD_APPENDKV_API - else if(0 < rotary_dim) - { - std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" - << std::endl; - rotary_dim = 0; - } -#endif - if(!(rotary_dim <= hdim_q)) - { - std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; - return false; - } - else if(!(rotary_dim % 16 == 0)) - { - std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; - return false; - } - const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); @@ -537,7 +538,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } #if CK_TILE_FMHA_FWD_SPLITKV_API - if(0 < p_drop && (1 < num_splits || use_cache_batch_idx || 0 < page_block_size)) + if(0 < p_drop && (1 < num_splits || use_kvcache)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" << std::endl; @@ -605,11 +606,11 @@ bool run(const ck_tile::ArgParser& arg_parser) generate_rotary_cos_sin(shape_seqlen_k, rotary_dim, seed); ck_tile::HostTensor lse_acc_host( - 1 < num_splits || use_cache_batch_idx || 0 < page_block_size + 1 < num_splits || use_kvcache ? std::array{num_splits, batch, nhead, max_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits || use_cache_batch_idx || 0 < page_block_size + 1 < num_splits || use_kvcache ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} : std::array{1, 1, 1, 1, 1}); @@ -1034,7 +1035,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_cache_batch_idx || 0 < page_block_size) + if(1 < num_splits || use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits);