mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
tempsave, fmha_decode
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;fwd_decode")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd_decode" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
@@ -13,12 +13,12 @@ foreach(api ${FMHA_FWD_ENABLE_APIS})
|
||||
endforeach()
|
||||
|
||||
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
|
||||
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
# list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
# endif()
|
||||
|
||||
# Filtering kernel
|
||||
# set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
|
||||
# set(KERNEL fmha_fwd_decode_d64_bf16_batch_b16x64x32x64x32x64_r1x4x1_r1x4x1_w16x16x32_w16x16x32_decode_qr_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
|
||||
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
@@ -106,6 +106,13 @@ else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the fwd_decode API in fmha_fwd example
|
||||
if("fwd_decode" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=1)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally specify the use of OCP_FP8
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
|
||||
Reference in New Issue
Block a user