mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 11:48:01 +00:00
149 lines
4.6 KiB
Python
Executable File
149 lines
4.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from utils import compare_versions, get_repo_root, normalize_version, validate_version
|
|
|
|
# Repo-relative paths of the files this script rewrites and then validates
# when the flashinfer version is bumped.
FILES_TO_UPDATE = [
    Path(rel_path)
    for rel_path in (
        "python/pyproject.toml",
        "docker/Dockerfile",
        "python/sglang/srt/entrypoints/engine.py",
        "python/sglang/srt/utils/common.py",
    )
]
|
|
|
|
|
|
def read_current_flashinfer_version(repo_root: Path) -> str:
    """Read the current flashinfer version from python/pyproject.toml.

    The version is taken from the first ``flashinfer_python==<version>`` pin
    found in the file. Recognised shapes are ``X.Y.Z``, ``X.Y.ZrcN`` and
    ``X.Y.Z.postN``.

    Args:
        repo_root: Directory that contains the ``python/pyproject.toml`` file.

    Returns:
        The pinned version string, including any ``rcN`` / ``.postN`` suffix.

    Raises:
        ValueError: If no ``flashinfer_python==<version>`` pin is present.
        OSError: If the pyproject file cannot be read (propagated from
            ``Path.read_text``).
    """
    pyproject = repo_root / "python" / "pyproject.toml"
    # TOML documents are UTF-8 by specification, so decode explicitly instead
    # of relying on the platform's locale encoding.
    content = pyproject.read_text(encoding="utf-8")
    match = re.search(
        r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)", content
    )
    if match is None:
        raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
    return match.group(1)
|
|
|
|
|
|
def replace_flashinfer_version(
    file_path: Path, old_version: str, new_version: str
) -> bool:
    """Rewrite the pinned flashinfer version inside a single file.

    The substitution rule is selected by the file's base name:

    * ``pyproject.toml`` -- ``flashinfer_python==<ver>`` and
      ``flashinfer_cubin==<ver>`` pins.
    * ``Dockerfile`` -- the ``ARG FLASHINFER_VERSION=<ver>`` line.
    * ``engine.py`` -- the quoted version passed as the second argument of
      ``assert_pkg_version("flashinfer_python", ...)``.
    * ``common.py`` -- the ``e.g., "<ver>"`` example text.

    A file with any other name is read but never modified.

    Args:
        file_path: File to update.
        old_version: Version string currently expected in the file.
        new_version: Version string to write in its place.

    Returns:
        True if the file was rewritten; False if it does not exist or no
        occurrence of ``old_version`` matched the rule for this file.
    """
    if not file_path.exists():
        print(f"Warning: {file_path} does not exist, skipping")
        return False

    # Decode/encode explicitly so the result does not depend on the locale.
    content = file_path.read_text(encoding="utf-8")
    new_content = content

    old_pattern = re.escape(old_version)
    # The pin must END where old_version ends. Without this guard, bumping
    # "0.6.4" would also rewrite the prefix of "0.6.4.post1" or "0.6.40" and
    # leave a corrupt version behind.
    version_end = r"(?!\.?\w)"

    name = file_path.name
    if name == "pyproject.toml":
        # Callable replacements insert new_version verbatim (no template
        # escape processing).
        new_content = re.sub(
            rf"(flashinfer_(?:python|cubin)==){old_pattern}{version_end}",
            lambda m: m.group(1) + new_version,
            new_content,
        )
    elif name == "Dockerfile":
        new_content = re.sub(
            rf"(ARG FLASHINFER_VERSION=){old_pattern}{version_end}",
            lambda m: m.group(1) + new_version,
            new_content,
        )
    elif name == "engine.py":
        # \s* also spans newlines, so a call formatted over several lines
        # still matches. The closing quote already delimits the version.
        new_content = re.sub(
            r'(assert_pkg_version\(\s*"flashinfer_python",\s*)"'
            + old_pattern
            + r'"',
            lambda m: m.group(1) + '"' + new_version + '"',
            new_content,
        )
    elif name == "common.py":
        new_content = new_content.replace(
            f'e.g., "{old_version}"',
            f'e.g., "{new_version}"',
        )

    if content == new_content:
        print(f"No changes needed in {file_path}")
        return False

    file_path.write_text(new_content, encoding="utf-8")
    print(f"✓ Updated {file_path}")
    return True
|
|
|
|
|
|
def main():
    """Command-line entry point: bump the flashinfer version in tracked files.

    Steps, in order:

    1. Parse the ``new_version`` argument, normalize it and validate it.
    2. Read the current version from ``python/pyproject.toml`` and require
       that ``compare_versions(new, old)`` is positive (0 is reported as
       "same", a negative result as "older").
    3. Rewrite every path in ``FILES_TO_UPDATE``.
    4. Re-read each existing file and check that it contains the new version.

    Exits with status 1 on an invalid version, a version that is not newer
    than the current one, or a failed validation.
    """
    cli = argparse.ArgumentParser(
        description="Bump flashinfer version across all relevant files"
    )
    cli.add_argument(
        "new_version",
        help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
    )
    new_version = normalize_version(cli.parse_args().new_version)

    if not validate_version(new_version):
        for line in (
            f"Error: Invalid version format: {new_version}",
            "Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN",
            "Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1",
        ):
            print(line)
        sys.exit(1)

    repo_root = get_repo_root()
    old_version = read_current_flashinfer_version(repo_root)
    print(f"Current flashinfer version: {old_version}")
    print(f"New flashinfer version: {new_version}")
    print()

    # Only a strictly newer version is accepted.
    ordering = compare_versions(new_version, old_version)
    if ordering == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    if ordering < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    # Pair each repo-relative path with its absolute location once; both the
    # rewrite pass and the validation pass walk the same list.
    targets = [(rel_path, repo_root / rel_path) for rel_path in FILES_TO_UPDATE]

    updated_count = sum(
        1
        for _, abs_path in targets
        if replace_flashinfer_version(abs_path, old_version, new_version)
    )

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Flashinfer version bumped from {old_version} to {new_version}")

    print("\nValidating version updates...")
    failed_files = []
    for rel_path, abs_path in targets:
        if not abs_path.exists():
            print(f"Warning: File {rel_path} does not exist, skipping validation.")
            continue

        if new_version in abs_path.read_text():
            print(f"✓ {rel_path} validated")
        else:
            failed_files.append(rel_path)
            print(f"✗ {rel_path} does not contain version {new_version}")

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for rel_path in failed_files:
            print(f" - {rel_path}")
        sys.exit(1)

    print("\nAll files validated successfully!")
|
|
|
|
|
|
# Run the bump only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|