mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Fixed editable install to depend on CuTeDSL/requirements.txt (#2768)
To guarantee wheel version alignment of the source code.
This commit is contained in:
@@ -40,6 +40,27 @@ class CutlassDSLSetupError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_package_spec(requirements_path: Optional[Path] = None) -> str:
|
||||
"""
|
||||
Return the pip requirement spec for nvidia-cutlass-dsl from requirements.txt.
|
||||
|
||||
If anything goes wrong (file not found, parse failure, line missing),
|
||||
return PACKAGE_NAME as a safe default.
|
||||
"""
|
||||
try:
|
||||
req_path = requirements_path or Path(__file__).with_name("requirements.txt")
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
for raw_line in f:
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if line.lower().startswith(PACKAGE_NAME):
|
||||
return line.split("#", 1)[0].strip()
|
||||
except Exception:
|
||||
pass
|
||||
return PACKAGE_NAME
|
||||
|
||||
|
||||
def download_wheel(temp_dir: Path) -> Path:
|
||||
"""
|
||||
Download the nvidia-cutlass-dsl wheel to a temporary directory.
|
||||
@@ -53,7 +74,10 @@ def download_wheel(temp_dir: Path) -> Path:
|
||||
Raises:
|
||||
CutlassDSLSetupError: If download fails or wheel not found
|
||||
"""
|
||||
logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
|
||||
# Resolve package spec from requirements, or fall back to PACKAGE_NAME
|
||||
package_spec = get_package_spec()
|
||||
|
||||
logger.info(f"Downloading {package_spec} wheel to {temp_dir}")
|
||||
|
||||
try:
|
||||
subprocess.check_call(
|
||||
@@ -63,7 +87,7 @@ def download_wheel(temp_dir: Path) -> Path:
|
||||
"pip",
|
||||
"download",
|
||||
"--no-deps",
|
||||
PACKAGE_NAME,
|
||||
package_spec,
|
||||
"--dest",
|
||||
str(temp_dir),
|
||||
],
|
||||
@@ -79,7 +103,7 @@ def download_wheel(temp_dir: Path) -> Path:
|
||||
raise CutlassDSLSetupError(error_msg)
|
||||
|
||||
# Find the downloaded wheel file
|
||||
wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
|
||||
wheel_pattern = f"*.whl"
|
||||
wheel_files = list(temp_dir.glob(wheel_pattern))
|
||||
if not wheel_files:
|
||||
raise CutlassDSLSetupError(
|
||||
@@ -108,7 +132,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
|
||||
# Construct version regex from package name
|
||||
# Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
|
||||
package_pattern = PACKAGE_NAME.replace("-", "_")
|
||||
version_regex = rf"{re.escape(package_pattern)}-([^-]+)-"
|
||||
version_regex = rf"{re.escape(package_pattern)}-([^-]+)"
|
||||
version_match = re.match(version_regex, wheel_filename)
|
||||
|
||||
if version_match:
|
||||
@@ -132,10 +156,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str:
|
||||
|
||||
return dev_version
|
||||
else:
|
||||
raise CutlassDSLSetupError(
|
||||
f"Could not parse version from wheel filename: {wheel_filename}"
|
||||
)
|
||||
|
||||
return "9.9.9.dev0"
|
||||
|
||||
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user