cmake_minimum_required(VERSION 3.24 FATAL_ERROR)
project(sgl_kernel LANGUAGES CXX)

# CMake defaults
# Language/toolchain: C++17 is mandatory (no silent fallback to an older standard).
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 17)
# Output naming / codegen for the shared extension module: no "lib" prefix, PIC everywhere.
set(CMAKE_SHARED_LIBRARY_PREFIX "")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Build-log ergonomics: colored compiler diagnostics and full command lines.
set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")

# Python / Torch
# Locate the Python interpreter and extension-module development files.
# SKBUILD_SABI_COMPONENT is expected to come from scikit-build-core (presumably empty
# when no stable-ABI build is requested — TODO confirm).
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

# Ask the selected interpreter where PyTorch ships its CMake package files.
# COMMAND_ERROR_IS_FATAL stops configuration right here, with Python's own error output,
# if `import torch` fails. Without it, TORCH_PY_PREFIX would be empty and the failure would
# only surface below as a confusing "Could not find Torch" against "/Torch".
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TORCH_PY_PREFIX
  OUTPUT_STRIP_TRAILING_WHITESPACE
  COMMAND_ERROR_IS_FATAL ANY
)

set(Torch_DIR "${TORCH_PY_PREFIX}/Torch")
list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch")
find_package(Torch REQUIRED)

# Query the libstdc++ dual-ABI setting torch was built with, and compile this directory's
# targets with the same _GLIBCXX_USE_CXX11_ABI value so the extension agrees with libtorch.
# The query must succeed and yield exactly 0 or 1. Previously, any failure silently fell back
# to ABI=1, which could produce a module that does not link/load against torch.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
  OUTPUT_VARIABLE TORCH_CXX11_ABI
  OUTPUT_STRIP_TRAILING_WHITESPACE
  COMMAND_ERROR_IS_FATAL ANY
)
if(NOT TORCH_CXX11_ABI MATCHES "^[01]$")
  message(FATAL_ERROR "Unexpected _GLIBCXX_USE_CXX11_ABI value reported by torch: '${TORCH_CXX11_ABI}'")
endif()
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})

# ROCm/HIP
# Enable the HIP language, then locate the HIP CMake package.
# The ROCm root defaults to /opt/rocm (the previous hard-coded value). It can be overridden
# with the ROCM_PATH environment variable for non-default ROCm installations.
enable_language(HIP)
set(SGL_ROCM_PATH "$ENV{ROCM_PATH}")
if(NOT SGL_ROCM_PATH)
  set(SGL_ROCM_PATH "/opt/rocm")
endif()
list(APPEND CMAKE_PREFIX_PATH "${SGL_ROCM_PATH}/lib/cmake/hip-lang")
find_package(hip REQUIRED CONFIG)

# Determine AMDGPU target from environment variable or default to gfx942
# The AMDGPU_TARGET environment variable is read at configure time. An empty or false-like
# value means "not set", in which case only gfx942 is built.
set(AMDGPU_TARGET_ENV "$ENV{AMDGPU_TARGET}")

if(NOT AMDGPU_TARGET_ENV)
  set(AMDGPU_TARGETS "gfx942")
  message(STATUS "AMDGPU_TARGET not set, defaulting to gfx942")
else()
  set(AMDGPU_TARGETS "${AMDGPU_TARGET_ENV}")
  message(STATUS "Using AMDGPU_TARGET from environment: ${AMDGPU_TARGETS}")
endif()

# Every HIP target created later in this directory compiles for exactly these architectures.
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})

# FP8 macro selection
# HIP_FP8_TYPE_FNUZ=1 is always defined (used for gfx942 and for host compilation).
# HIP_FP8_TYPE_E4M3=1 is additionally defined whenever gfx950 is among the targets.
# utils.h is expected to pick the right FP8 type per architecture from these macros.
# Any target list that mentions neither gfx942 nor gfx950 is rejected.
set(SGL_FP8_MACROS "-DHIP_FP8_TYPE_FNUZ=1")

if(AMDGPU_TARGETS MATCHES "gfx950")
  list(APPEND SGL_FP8_MACROS "-DHIP_FP8_TYPE_E4M3=1")
  # Report accurately: a gfx950-only build is not a multi-arch build.
  if(AMDGPU_TARGETS MATCHES "gfx942")
    message(STATUS "Multi-arch build: Enabling both HIP_FP8_TYPE_FNUZ (gfx942) and HIP_FP8_TYPE_E4M3 (gfx950)")
  else()
    message(STATUS "Single-arch build: Enabling HIP_FP8_TYPE_E4M3 for gfx950 (HIP_FP8_TYPE_FNUZ kept for host compilation)")
  endif()
elseif(AMDGPU_TARGETS MATCHES "gfx942")
  message(STATUS "Single-arch build: Enabling HIP_FP8_TYPE_FNUZ for gfx942")
else()
  message(FATAL_ERROR "Unsupported AMDGPU_TARGET '${AMDGPU_TARGETS}'. Expected 'gfx942' or 'gfx950' or both.")
endif()

# TopK dynamic smem bytes
# Dynamic shared-memory budget for the TopK kernels.
# - gfx942 (MI300/MI325): LDS is typically 64KB per workgroup -> keep dynamic smem <= ~48KB
#   (leaves room for static shared allocations in the kernel).
# - gfx95x (MI350): LDS is larger (e.g. 160KB per CU) -> allow the original 128KB dynamic smem.
#
# The condition previously tested AMDGPU_TARGET_ONE, a variable that is never set. That made
# it always false, so even the default gfx942-only build got the 128KB budget.
# SGL_TOPK_MACROS ends up in the HIP flags of every device compilation pass, so a single value
# must fit every target. Whenever gfx942 is among the targets (including a gfx942+gfx950
# multi-arch build), the smaller gfx942 budget is used.
# NOTE(review): this assumes the TopK kernels work correctly with the smaller budget on gfx950
# too — confirm.
if(AMDGPU_TARGETS MATCHES "gfx942")
  math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "48 * 1024")
else()
  math(EXPR SGL_TOPK_DYNAMIC_SMEM_BYTES "32 * 1024 * 4")
endif()

set(SGL_TOPK_MACROS "-DSGL_TOPK_DYNAMIC_SMEM_BYTES=${SGL_TOPK_DYNAMIC_SMEM_BYTES}")

# Paths / includes
# PROJ_ROOT is the directory containing this CMake file. The project-local header roots
# (include/, include/impl/, csrc/) under it come first, followed by Torch's include dirs.
set(PROJ_ROOT ${CMAKE_CURRENT_LIST_DIR})
set(SGL_INCLUDE_DIRS include include/impl csrc)
list(TRANSFORM SGL_INCLUDE_DIRS PREPEND "${PROJ_ROOT}/")
list(APPEND SGL_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS})

# Platform-specific library directory
# Debian-style multiarch system library directory. It is added to the link search path of
# targets created after this point in this directory. The multiarch tuple comes from
# CMAKE_LIBRARY_ARCHITECTURE when CMake detected one (e.g. aarch64-linux-gnu). Otherwise it
# falls back to the previously hard-coded x86_64 tuple, so existing x86_64 builds are unchanged.
# NOTE(review): target_link_directories(common_ops ...) would scope this to the one target.
if(CMAKE_LIBRARY_ARCHITECTURE)
  set(PLAT_LIB_DIR "/usr/lib/${CMAKE_LIBRARY_ARCHITECTURE}")
else()
  set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
endif()
link_directories(${PLAT_LIB_DIR})

# Sources
# Every translation unit of common_ops, listed relative to PROJ_ROOT and made absolute below.
set(SOURCES
  csrc/allreduce/custom_all_reduce.hip
  csrc/allreduce/deterministic_all_reduce.hip
  csrc/allreduce/quick_all_reduce.hip
  csrc/common_extension_rocm.cc
  csrc/elementwise/activation.hip
  csrc/elementwise/pos_enc.hip
  csrc/elementwise/topk.hip
  csrc/grammar/apply_token_bitmask_inplace_hip.hip
  csrc/kvcacheio/transfer.hip
  csrc/memory/weak_ref_tensor.cpp
  csrc/moe/moe_align_kernel.hip
  csrc/moe/moe_topk_softmax_kernels.hip
  csrc/moe/moe_topk_sigmoid_kernels.hip
  csrc/speculative/eagle_utils.hip
)
list(TRANSFORM SOURCES PREPEND "${PROJ_ROOT}/")

# Compile every source, including the .cc/.cpp files, with the HIP compiler.
set_source_files_properties(${SOURCES} PROPERTIES LANGUAGE HIP)

# Compile / Link flags
# Directory-scoped -O3, applied only to CXX-language compiles of targets created later in
# this directory. All files in SOURCES are marked LANGUAGE HIP above, so for common_ops the
# HIP flag list below is what actually takes effect.
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-O3>)

# HIP flags for common_ops, built up group by group (the overall order is significant
# and is preserved).
set(SGL_HIP_FLAGS "")
# Build mode and the namespace the operators are registered under.
list(APPEND SGL_HIP_FLAGS -DNDEBUG -DOPERATOR_NAMESPACE=sgl_kernel)
# Optimization level and language standard for device/host HIP code.
list(APPEND SGL_HIP_FLAGS -O3 -std=c++17)
# Numeric-type feature switches, including the per-architecture FP8 macros.
list(APPEND SGL_HIP_FLAGS -DENABLE_BF16 -DENABLE_FP8 ${SGL_FP8_MACROS})
# Diagnostics tweaks, then the TopK shared-memory budget macro.
list(APPEND SGL_HIP_FLAGS -Wno-pass-failed -Wundefined-internal ${SGL_TOPK_MACROS})

# Python extension
# The common_ops extension module, built from SOURCES. USE_SABI / WITH_SOABI control the
# module's ABI tagging; SKBUILD_SABI_VERSION presumably comes from scikit-build-core (confirm).
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

# Private usage requirements: project + Torch headers; HIP flags gated to HIP compiles only.
target_include_directories(common_ops PRIVATE ${SGL_INCLUDE_DIRS})
target_compile_options(common_ops PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:${SGL_HIP_FLAGS}>")

# Link Torch, the HIP imported targets, and the HIP RTC / runtime libraries (by name).
target_link_libraries(common_ops
  PRIVATE
    ${TORCH_LIBRARIES}
    hip::device
    hip::host
    hiprtc
    amdhip64
)

# Runtime search path for torch's shared libraries, relative to the module's own location.
# NOTE(review): with DESTINATION sgl_kernel below, $ORIGIN is the sgl_kernel directory, so
# "../.." climbs above the package's parent directory. Confirm that this resolves to
# torch/lib in the installed layout.
target_link_options(common_ops PRIVATE
  "SHELL:-Wl,-rpath,'\$ORIGIN/../../torch/lib'"
)

install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
