mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-03 13:57:04 +00:00
157 lines
6.7 KiB
CMake
157 lines
6.7 KiB
CMake
include(FetchContent)
|
|
|
|
# FlashMLA kernel sources, fetched from the sgl-project/FlashMLA repository
# and pinned to one exact commit. GIT_SHALLOW is disabled for this checkout.
FetchContent_Declare(
    repo-flashmla
    GIT_REPOSITORY "https://github.com/sgl-project/FlashMLA"
    GIT_TAG        "abb54777d4e08c8054c238f59889b52d4e9f0896"
    GIT_SHALLOW    OFF
)

# Populate the checkout; the rest of this file refers to it through
# ${repo-flashmla_SOURCE_DIR}.
FetchContent_Populate(repo-flashmla)
|
|
|
|
# Base nvcc options for the FlashMLA CUDA sources. Architecture-specific
# -gencode entries are appended to this list separately.
set(FLASHMLA_CUDA_FLAGS "")
list(APPEND FLASHMLA_CUDA_FLAGS
    "--expt-relaxed-constexpr"
    "--expt-extended-lambda"
    "--use_fast_math"
)

# Silence cudafe diagnostic 177: variable was declared but never referenced.
list(APPEND FLASHMLA_CUDA_FLAGS "-Xcudafe=--diag_suppress=177")
|
|
|
|
# The FlashMLA kernels target Hopper (sm90a) and require CUDA 12.4 or later,
# so the sm_90a -gencode entry is only added for a compatible toolkit.
#
# The comparison is VERSION_GREATER_EQUAL so that a CUDA_VERSION of exactly
# "12.4" (major.minor form) qualifies, as the requirement above states.
# CUDA_VERSION is expanded inside quotes so an empty or undefined value makes
# the condition false instead of producing a malformed if() expression.
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
    list(APPEND FLASHMLA_CUDA_FLAGS
        "-gencode=arch=compute_90a,code=sm_90a"
    )
endif()

# sm_100a code generation is gated on CUDA 12.8 or later (12.8 included).
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
    list(APPEND FLASHMLA_CUDA_FLAGS
        "-gencode=arch=compute_100a,code=sm_100a"
    )
endif()
|
|
if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
    # Patch FlashMLA sources for SM103a support.
    # These patches are only needed (and only valid) with CUDA 13+.
    #
    # Both patches edit the fetched sources in place at configure time. A file
    # is rewritten only when its anchor text is actually present, so a
    # re-configure does not touch an already-patched file, and a missing
    # anchor is reported as a warning instead of a false "Patched" message.

    # Patch utils.h: widen IS_SM100 to cover the full SM100 family.
    # Newer FlashMLA versions use csrc/utils.h.
    set(FLASHMLA_UTILS_FILE "${repo-flashmla_SOURCE_DIR}/csrc/utils.h")
    set(FLASHMLA_IS_SM100_OLD
        "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)
#define IS_SM100 1")
    set(FLASHMLA_IS_SM100_NEW
        "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && (__CUDA_ARCH__ < 1100)
#define IS_SM100 1")

    file(READ "${FLASHMLA_UTILS_FILE}" FLASHMLA_UTILS_CONTENT)
    string(FIND "${FLASHMLA_UTILS_CONTENT}" "${FLASHMLA_IS_SM100_OLD}" FLASHMLA_IS_SM100_OLD_POS)
    string(FIND "${FLASHMLA_UTILS_CONTENT}" "${FLASHMLA_IS_SM100_NEW}" FLASHMLA_IS_SM100_NEW_POS)
    if(NOT FLASHMLA_IS_SM100_OLD_POS EQUAL -1)
        string(REPLACE
            "${FLASHMLA_IS_SM100_OLD}"
            "${FLASHMLA_IS_SM100_NEW}"
            FLASHMLA_UTILS_CONTENT "${FLASHMLA_UTILS_CONTENT}")
        file(WRITE "${FLASHMLA_UTILS_FILE}" "${FLASHMLA_UTILS_CONTENT}")
        message(STATUS "Patched utils.h for SM103a support")
    elseif(NOT FLASHMLA_IS_SM100_NEW_POS EQUAL -1)
        # The widened condition is already in the file (earlier configure run).
        message(STATUS "utils.h already patched for SM103a")
    else()
        message(WARNING
            "IS_SM100 definition not found in ${FLASHMLA_UTILS_FILE}; "
            "utils.h was NOT patched for SM103a support")
    endif()

    # Patch cutlass/arch/config.h: add SM103 architecture defines.
    # The new block is inserted right before the existing "// SM101 and SM101a"
    # anchor in the upstream header.
    set(CUTLASS_CONFIG_FILE "${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include/cutlass/arch/config.h")
    set(CUTLASS_SM101_ANCHOR "// SM101 and SM101a")

    file(READ "${CUTLASS_CONFIG_FILE}" CUTLASS_CONFIG_CONTENT)
    string(FIND "${CUTLASS_CONFIG_CONTENT}" "SM103" SM103_FOUND)
    if(SM103_FOUND EQUAL -1)
        string(FIND "${CUTLASS_CONFIG_CONTENT}" "${CUTLASS_SM101_ANCHOR}" CUTLASS_SM101_ANCHOR_POS)
        if(CUTLASS_SM101_ANCHOR_POS EQUAL -1)
            message(WARNING
                "Anchor '${CUTLASS_SM101_ANCHOR}' not found in ${CUTLASS_CONFIG_FILE}; "
                "cutlass/arch/config.h was NOT patched for SM103a support")
        else()
            string(REPLACE
                "${CUTLASS_SM101_ANCHOR}"
                "// SM103 and SM103a
#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ >= 13)
#define CUTLASS_ARCH_MMA_SM103_SUPPORTED 1
#if (!defined(CUTLASS_ARCH_MMA_SM103_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1030)
#define CUTLASS_ARCH_MMA_SM103_ENABLED 1
#if !defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
#define CUTLASS_ARCH_MMA_SM100A_ENABLED 1
#endif
#if !defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)
#define CUTLASS_ARCH_MMA_SM100F_ENABLED 1
#endif
#endif
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

// SM101 and SM101a"
                CUTLASS_CONFIG_CONTENT "${CUTLASS_CONFIG_CONTENT}")
            file(WRITE "${CUTLASS_CONFIG_FILE}" "${CUTLASS_CONFIG_CONTENT}")
            message(STATUS "Patched cutlass/arch/config.h for SM103a support")
        endif()
    else()
        message(STATUS "cutlass/arch/config.h already patched for SM103a")
    endif()

    list(APPEND FLASHMLA_CUDA_FLAGS
        "-gencode=arch=compute_103a,code=sm_103a"
    )
endif()
|
|
|
|
|
|
# Sources for the flashmla_ops module. The first entry is a path relative to
# this project; every other entry lives under the fetched FlashMLA checkout
# and is listed below relative to its csrc/ directory.
set(FlashMLA_SOURCES "csrc/flashmla_extension.cc")

foreach(flashmla_rel_src IN ITEMS
    # Compatibility shim for sgl-kernel torch.ops API.
    "python_api.cpp"

    # Decode metadata/combine kernels.
    "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu"
    "smxx/decode/combine/combine.cu"

    # sm90 dense decode.
    "sm90/decode/dense/instantiations/fp16.cu"
    "sm90/decode/dense/instantiations/bf16.cu"

    # sm90 sparse decode.
    "sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu"
    "sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu"
    "sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu"
    "sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu"

    # sm90 sparse prefill.
    "sm90/prefill/sparse/fwd.cu"
    "sm90/prefill/sparse/instantiations/phase1_k512.cu"
    "sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu"
    "sm90/prefill/sparse/instantiations/phase1_k576.cu"
    "sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu"

    # sm100 dense prefill/bwd.
    "sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu"
    "sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu"

    # sm100 sparse prefill.
    "sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu"
    "sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu"
    "sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu"
    "sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu"
    "sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu"

    # sm100 sparse decode.
    "sm100/decode/head64/instantiations/v32.cu"
    "sm100/decode/head64/instantiations/model1.cu"
    "sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu"

    # Sources under extension/sm90/dense_fp8.
    "extension/sm90/dense_fp8/dense_fp8_python_api.cpp"
    "extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu"
    "extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu"
)
    list(APPEND FlashMLA_SOURCES "${repo-flashmla_SOURCE_DIR}/csrc/${flashmla_rel_src}")
endforeach()
|
|
|
|
# Build the FlashMLA sources as the flashmla_ops Python extension module.
Python_add_library(flashmla_ops
    MODULE
    USE_SABI ${SKBUILD_SABI_VERSION}
    WITH_SOABI
    ${FlashMLA_SOURCES}
)

# Header search paths inside the fetched FlashMLA checkout, including the
# CUTLASS headers located under its csrc/cutlass directory.
target_include_directories(flashmla_ops PRIVATE
    "${repo-flashmla_SOURCE_DIR}/csrc"
    "${repo-flashmla_SOURCE_DIR}/csrc/kerutils/include"
    "${repo-flashmla_SOURCE_DIR}/csrc/sm90"
    "${repo-flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/"
    "${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include"
    "${repo-flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include"
)

# C++20 for both host and CUDA compilation; FLASHMLA_CUDA_FLAGS is passed to
# CUDA compilation only. The generator expressions are quoted so that the
# semicolon-separated flag list stays inside a single expression.
target_compile_options(flashmla_ops PRIVATE
    "$<$<COMPILE_LANGUAGE:CXX>:-std=c++20>"
    "$<$<COMPILE_LANGUAGE:CUDA>:-std=c++20>"
    "$<$<COMPILE_LANGUAGE:CUDA>:${FLASHMLA_CUDA_FLAGS}>"
)

# Called with no definitions, so it currently adds nothing to the target.
target_compile_definitions(flashmla_ops PRIVATE)

target_link_libraries(flashmla_ops PRIVATE
    ${TORCH_LIBRARIES}
    c10
    cuda
)

# The built module is installed into the sgl_kernel destination directory.
install(TARGETS flashmla_ops LIBRARY DESTINATION "sgl_kernel")
|