diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index eb344c04e0..0196caceab 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -257,9 +257,9 @@ int override_num_splits_if_necessary( template bool run(const ck_tile::ArgParser& arg_parser) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); @@ -323,11 +323,11 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::endl; use_cache_batch_idx = false; } - - auto mode = static_cast(arg_parser.get_uint32("mode")); - if((use_cache_batch_idx || 0 < page_block_size) && mode != mode_enum::batch) + // the input layout we use for kvcache is same as batch mode + if((0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size) && + mode != mode_enum::batch) { - std::cerr << "both kvcache & split-kv enabled. ignoring the 'mode' option" << std::endl; + std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; mode = mode_enum::batch; }