diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index fe9d689449..fd50cfe7b8 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -88,6 +88,13 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) endif() +# conditionally enable call to the fwd_appendkv API in fmha_fwd example +if ("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index cc9c1f4f3a..a8dce0fd51 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -289,6 +289,13 @@ bool run(const ck_tile::ArgParser& arg_parser) arg_parser.get_str("s_kpad")); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_k_new"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < seqlen_knew) + { + std::cerr << "append-kv is not supported" << std::endl; + return false; + } +#endif #if 0 // clang-format off @@ -638,6 +645,7 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = 0; +#if CK_TILE_FMHA_FWD_APPENDKV_API if(0 < seqlen_knew) { auto appendkv_traits = fmha_fwd_appendkv_traits{ @@ -724,6 +732,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); } +#endif auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v,