tempsave, fmha_decode

This commit is contained in:
aska-0096
2025-07-08 08:37:20 +00:00
parent 47565f21a5
commit 18686cfe5b
13 changed files with 2562 additions and 71 deletions

View File

@@ -1,6 +1,6 @@
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;fwd_decode")
set(FMHA_FWD_ENABLE_APIS "fwd_decode" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
@@ -13,12 +13,12 @@ foreach(api ${FMHA_FWD_ENABLE_APIS})
endforeach()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
# list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
# endif()
# Filtering kernel
# set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
# set(KERNEL fmha_fwd_decode_d64_bf16_batch_b16x64x32x64x32x64_r1x4x1_r1x4x1_w16x16x32_w16x16x32_decode_qr_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
@@ -106,6 +106,13 @@ else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example
if("fwd_decode" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=0)
endif()
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)