# Mirror of https://github.com/kvcache-ai/sglang.git
# Synced 2026-06-30 19:57:52 +00:00
# 208 lines, 6.1 KiB, Python
# Copyright 2025 SGLang Team. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import os
|
|
import platform
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# isort: off
|
|
import torch
|
|
import torchada # noqa: F401
|
|
|
|
# isort: on
|
|
from setuptools import find_packages, setup
|
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
|
|
|
root = Path(__file__).parent.resolve()
|
|
third_party = Path(os.environ.get("SGLANG_MUSA_THIRD_PARTY_DIR", "build/_deps"))
|
|
arch = platform.machine().lower()
|
|
|
|
|
|
class _RepoInfo:
    """Describe one third-party git repository needed by the build.

    The constructor arguments are stored unchanged as attributes of the same
    name. In addition ``source_dir`` holds the checkout location, which is the
    module-level ``third_party`` directory joined with ``name``.
    """

    def __init__(self, name, git_repository, git_tag, git_shallow=False):
        # On-disk location of the checkout: <third_party>/<name>.
        self.source_dir = third_party / name
        # Remote URL and the ref (commit/tag) that should be checked out.
        self.git_repository, self.git_tag = git_repository, git_tag
        # Whether a shallow clone is requested for this repository.
        self.git_shallow = git_shallow
        self.name = name
|
|
|
|
|
|
# Third-party sources are pinned to exact commits so builds are reproducible.
_FLASHINFER_REPO = _RepoInfo(
    "flashinfer",
    git_repository="https://github.com/flashinfer-ai/flashinfer.git",
    git_tag="bc29697ba20b7e6bdb728ded98f04788e16ee021",
    git_shallow=False,
)

_MUTLASS_REPO = _RepoInfo(
    "mutlass",
    git_repository="https://github.com/MooreThreads/mutlass.git",
    git_tag="3abd6a728aacd190df0d922514aca8a8bc3c46b7",
    git_shallow=False,
)
|
|
|
|
|
|
def _get_version():
    """Return the package version declared in ``pyproject.toml``.

    The file ``root / "pyproject.toml"`` is scanned line by line and the value
    of the first ``version = ...`` assignment is returned.

    Returns:
        The version string with surrounding quotes and any trailing
        ``# comment`` removed, or ``None`` when no line assigns ``version``.
    """
    with open(root / "pyproject.toml", encoding="utf-8") as f:
        for line in f:
            key, sep, value = line.partition("=")
            # Require an exact key match: a prefix test would also accept
            # unrelated keys such as ``version_scheme`` or ``versions``.
            if not sep or key.strip() != "version":
                continue
            # Drop an inline comment before stripping whitespace and quotes.
            value = value.split("#", 1)[0].strip()
            return value.strip("\"'")
    # No version assignment found; kept as ``None`` for backward compatibility.
    return None
|
|
|
|
|
|
# Value substituted into the -DOPERATOR_NAMESPACE compiler define.
operator_namespace = "sgl_kernel"

# Header search paths: this project's own trees first, then the third-party
# checkouts (FlashInfer headers and sources, MUTLASS headers).
include_dirs = [root / rel for rel in ("include", "include/impl", "csrc")]
include_dirs += [
    root / _FLASHINFER_REPO.source_dir / "include",
    root / _FLASHINFER_REPO.source_dir / "csrc",
    root / _MUTLASS_REPO.source_dir / "include",
]

# Translation units: the local extension entry point plus the FlashInfer
# kernels that are compiled straight from the third-party checkout.
sources = ["csrc/common_extension_musa.cc"]
sources.extend(
    str(_FLASHINFER_REPO.source_dir / "csrc" / kernel)
    for kernel in ("norm.cu", "renorm.cu", "sampling.cu")
)

# Passed as the "cxx" entry of extra_compile_args.
# NOTE(review): "force_mcc" is presumably a marker understood by the MUSA
# build tooling rather than a real compiler switch -- confirm.
cxx_flags = ["force_mcc"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = [
    # Runtime library search path, relative to the built shared object.
    "-Wl,-rpath,$ORIGIN/../../torch/lib",
    # Multiarch system library directory for the host CPU architecture.
    "-L/usr/lib/{}-linux-gnu".format(arch),
    "-lmublasLt",
]
|
|
|
|
# Select the GPU architecture to compile for. Precedence:
#   1. an explicit MTGPU_TARGET environment variable,
#   2. the architecture of device 0 when torch.musa reports a usable GPU,
#   3. the built-in default.
# The explicit setting must win over auto-detection; otherwise it is
# impossible to build for a different target on a machine that has a GPU.
default_target = "mp_31"
mtgpu_target = os.environ.get("MTGPU_TARGET")

if mtgpu_target:
    print(f"Using MTGPU_TARGET from environment: {mtgpu_target}")
elif torch.musa.is_available():
    mtgpu_target = default_target
    try:
        prop = torch.musa.get_device_properties(0)
        mtgpu_target = f"mp_{prop.major}{prop.minor}"
    except Exception as e:
        # Detection is best effort: fall back to the default and say so.
        print(
            f"Warning: Failed to detect GPU properties: {e}. "
            f"Using default target: {mtgpu_target}"
        )
else:
    mtgpu_target = default_target
    print(f"Warning: torch.musa not available. Using default target: {mtgpu_target}")

if mtgpu_target not in ("mp_22", "mp_31"):
    # This is fatal, so report it as an error on stderr before exiting.
    print(
        f"Error: Unsupported GPU architecture '{mtgpu_target}'. Expected 'mp_22' or 'mp_31'.",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
# Flags handed to the device compiler (the "mcc" entry of extra_compile_args).
mcc_flags = [
    "-DNDEBUG",
    "-DOPERATOR_NAMESPACE={}".format(operator_namespace),
    "-O3",
    "-fPIC",
    "-std=c++17",
    "--cuda-gpu-arch={}".format(mtgpu_target),
    # "-x musa" is passed as two separate arguments.
    "-x",
    "musa",
    "-mtgpu",
    "-Od3",
    "-ffast-math",
    "-fmusa-flush-denormals-to-zero",
    "-fno-strict-aliasing",
    "-DUSE_MUSA",
    "-DENABLE_BF16",
    "-DFLASHINFER_ENABLE_F16",
    "-DFLASHINFER_ENABLE_BF16",
]

# The FP8 defines are only added when building for the mp_31 target.
if mtgpu_target == "mp_31":
    mcc_flags += [
        "-DENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8_E4M3",
        "-DFLASHINFER_ENABLE_FP8_E5M2",
    ]
|
|
|
|
# The single native extension built by this script. It is assembled from the
# source list, include paths, compiler flags and link settings defined at
# module level.
ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        libraries=libraries,
        # Separate flag sets for the device compiler and the host C++ compiler.
        extra_compile_args=dict(mcc=mcc_flags, cxx=cxx_flags),
        extra_link_args=extra_link_args,
        py_limited_api=False,
    )
]
|
|
|
|
|
|
class _CustomBuildExt(BuildExtension):
    """Build command that prepares the third-party sources before compiling.

    Unless ``SGLANG_MUSA_SKIP_THIRD_PARTY=1`` is set in the environment, the
    repositories described by ``_MUTLASS_REPO`` and ``_FLASHINFER_REPO`` are
    cloned (or updated) and checked out at their pinned refs. The regular
    ``BuildExtension.run`` then performs the build.
    """

    @staticmethod
    def _clone_and_checkout(repo_path, repo_url, git_tag, git_shallow):
        """Make ``repo_path`` a checkout of ``repo_url`` at ``git_tag``.

        Args:
            repo_path: ``pathlib.Path`` of the checkout directory. Its parent
                directories are created when missing.
            repo_url: URL of the remote repository.
            git_tag: Commit id, tag or branch to check out.
            git_shallow: When true, download only the requested ref with
                history depth 1 instead of the full repository.

        Raises:
            subprocess.CalledProcessError: If any git command fails.
        """
        repo_path.parent.mkdir(parents=True, exist_ok=True)
        is_new = not repo_path.exists()

        if git_shallow:
            # ``git clone --depth 1`` only downloads the tip of the default
            # branch, so a pinned commit or tag is normally missing and the
            # checkout fails. Fetch exactly the requested ref instead.
            # NOTE(review): fetching a raw commit id requires the server to
            # allow it; an existing checkout is assumed to have an "origin"
            # remote -- confirm for non-GitHub mirrors.
            if is_new:
                subprocess.check_call(["git", "init", str(repo_path)])
                subprocess.check_call(
                    ["git", "remote", "add", "origin", repo_url], cwd=repo_path
                )
            subprocess.check_call(
                ["git", "fetch", "--depth", "1", "origin", git_tag], cwd=repo_path
            )
            subprocess.check_call(["git", "checkout", "FETCH_HEAD"], cwd=repo_path)
            return

        if is_new:
            subprocess.check_call(["git", "clone", repo_url, str(repo_path)])
        else:
            # Refresh an existing checkout so a newly pinned ref is available.
            subprocess.check_call(["git", "fetch", "--all"], cwd=repo_path)
        subprocess.check_call(["git", "checkout", git_tag], cwd=repo_path)

    def run(self):
        """Prepare the third-party repositories, then run the normal build."""
        if os.environ.get("SGLANG_MUSA_SKIP_THIRD_PARTY", "0") == "1":
            print(
                "Skipping third-party repositories cloning (SGLANG_MUSA_SKIP_THIRD_PARTY=1)"
            )
        else:
            print("Cloning third-party repositories...")
            # Same order as before: MUTLASS first, then FlashInfer.
            for repo in (_MUTLASS_REPO, _FLASHINFER_REPO):
                self._clone_and_checkout(
                    repo.source_dir,
                    repo.git_repository,
                    repo.git_tag,
                    repo.git_shallow,
                )
            print("Third-party repositories ready.")

        super().run()
|
|
|
|
|
|
# Package definition: Python sources live under ./python, the native extension
# is built through the custom build_ext command with ninja enabled.
setup(
    name="sglang-kernel",
    version=_get_version(),
    package_dir={"": "python"},
    packages=find_packages(where="python"),
    ext_modules=ext_modules,
    cmdclass=dict(build_ext=_CustomBuildExt.with_options(use_ninja=True)),
    # NOTE(review): the wheel is tagged for the cp39 limited API while the
    # extension itself is declared with py_limited_api=False -- confirm this
    # combination is intended.
    options=dict(bdist_wheel=dict(py_limited_api="cp39")),
)
|