diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e30e9e793c..241213a6ec 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,27 @@ -# generate a list of kernels, but not actually emit files at config stage +# validate user-specified fmha_fwd API list +set(EXAMPLE_FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv") +set(EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${EXAMPLE_FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(EXAMPLE_FMHA_FWD_ENABLE_APIS STREQUAL "all") + set(EXAMPLE_FMHA_FWD_ENABLE_APIS ${EXAMPLE_FMHA_FWD_KNOWN_APIS}) +endif() + +foreach(api ${EXAMPLE_FMHA_FWD_ENABLE_APIS}) + if(NOT "${api}" IN_LIST EXAMPLE_FMHA_FWD_KNOWN_APIS) + message(FATAL_ERROR "${api} isn't a known api: ${EXAMPLE_FMHA_FWD_KNOWN_APIS}.") + endif() +endforeach() + +# "fwd" is a must-have api for the fmha_fwd example, add it if not specified +if(NOT "fwd" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd") +endif() + +string(REPLACE ";" "," EXAMPLE_FMHA_FWD_APIS "${EXAMPLE_FMHA_FWD_ENABLE_APIS}") +# generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api ${EXAMPLE_FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) execute_process( @@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api ${EXAMPLE_FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command( @@ -61,6 +81,13 @@ else() list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() +# conditionally enable call to the fwd_splitkv API in fmha_fwd example +if ("fwd_splitkv" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)