Fixed editable install to depend on CuTeDSL/requirements.txt (#2768)

To guarantee wheel version alignment of the source code.
This commit is contained in:
Zekun Fan
2025-11-14 15:31:49 -08:00
committed by GitHub
parent bd96096d58
commit a2439551c7

View File

@@ -40,6 +40,27 @@ class CutlassDSLSetupError(Exception):
pass
def get_package_spec(requirements_path: Optional[Path] = None) -> str:
"""
Return the pip requirement spec for nvidia-cutlass-dsl from requirements.txt.
If anything goes wrong (file not found, parse failure, line missing),
return PACKAGE_NAME as a safe default.
"""
try:
req_path = requirements_path or Path(__file__).with_name("requirements.txt")
with open(req_path, "r", encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
if line.lower().startswith(PACKAGE_NAME):
return line.split("#", 1)[0].strip()
except Exception:
pass
return PACKAGE_NAME
def download_wheel(temp_dir: Path) -> Path:
"""
Download the nvidia-cutlass-dsl wheel to a temporary directory.
@@ -53,7 +74,10 @@ def download_wheel(temp_dir: Path) -> Path:
Raises:
CutlassDSLSetupError: If download fails or wheel not found
"""
logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
# Resolve package spec from requirements, or fall back to PACKAGE_NAME
package_spec = get_package_spec()
logger.info(f"Downloading {package_spec} wheel to {temp_dir}")
try:
subprocess.check_call(
@@ -63,7 +87,7 @@ def download_wheel(temp_dir: Path) -> Path:
"pip",
"download",
"--no-deps",
PACKAGE_NAME,
package_spec,
"--dest",
str(temp_dir),
],
@@ -79,7 +103,7 @@ def download_wheel(temp_dir: Path) -> Path:
raise CutlassDSLSetupError(error_msg)
# Find the downloaded wheel file
wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
wheel_pattern = f"*.whl"
wheel_files = list(temp_dir.glob(wheel_pattern))
if not wheel_files:
raise CutlassDSLSetupError(
@@ -108,7 +132,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
# Construct version regex from package name
# Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
package_pattern = PACKAGE_NAME.replace("-", "_")
version_regex = rf"{re.escape(package_pattern)}-([^-]+)-"
version_regex = rf"{re.escape(package_pattern)}-([^-]+)"
version_match = re.match(version_regex, wheel_filename)
if version_match:
@@ -132,10 +156,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
return dev_version
else:
raise CutlassDSLSetupError(
f"Could not parse version from wheel filename: {wheel_filename}"
)
return "9.9.9.dev0"
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
"""