Files
cutlass/python/CuTeDSL/prep_editable_install.py
Junkai-Wu 39b352fa93 v4.6 dev update. (#3315)
* v4.6 dev update.

* Remove CUTLASS_HOST_DEVICE from CudaHostAdapter::memsetDevice (#3286)

* [SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM (#3280)

* gemm: add SM120 array TMA collective for tensor/token-scaled FP8 grouped GEMM

Adds CollectiveMma and CollectiveBuilder specializations for
MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM
(MoE expert dispatch) with tensor- and token-level FP8 scaling on
SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10).

New files:
- include/cutlass/gemm/collective/sm120_mma_array_tma.hpp
  CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized.
  Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules.
  Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B.
  Supports F8F6F4 MMA with TMA loads for both A and B operands.

- include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl
  CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized
  Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts
  from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized
  dispatch policy, produces correctly-typed CollectiveOp.

Modified files:
- collective_mma.hpp: include sm120_mma_array_tma.hpp
- collective_builder.hpp: include sm120_array_mma_builder.inl
- sm120_mma_builder.inl: remove ptr-array schedules from enable_if
  (they now route to sm120_array_mma_builder.inl) and drop the
  IsPtrArrayKernel static_assert that enforced the restriction

Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running
vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B
total / 4B active). Previously fell back to a non-CUTLASS Triton path;
with this patch, the SM120 CUTLASS grouped GEMM collective activates and
produces correct outputs. Short-sequence throughput improved ~7% vs the
fallback baseline (76.3 → 81.9 tok/s).

Closes #3263

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>

* test: add SM120 ptr-array grouped GEMM unit tests

Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder
specializations introduced for MainloopSm120ArrayTmaWarpSpecialized,
covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and
KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across
e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs,
and two tile shapes.

Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new
cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per
reviewer request in PR #3280.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>

---------

Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
Co-authored-by: Alex Georgiev <89279829+alexngUNC@users.noreply.github.com>
Co-authored-by: Tyler <tgmerritt@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
2026-06-15 23:23:20 -04:00

311 lines
9.7 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
CuTeDSL Development Package Setup
This setup script automatically downloads the nvidia-cutlass-dsl wheel,
extracts required libraries and Python packages, and sets up the development
environment for CuTeDSL.
"""
import subprocess
import sys
import shutil
import tempfile
import zipfile
import re
from pathlib import Path
from typing import Tuple
import logging
# Configure logging at import time: INFO and above, rendered as
# "LEVEL: message". Note that logger.debug() calls in this module are
# therefore not shown with this configuration.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
# Constants
# Distribution name handed to `pip download`; the wheel filename is matched
# using the same name with dashes replaced by underscores.
PACKAGE_NAME = "nvidia-cutlass-dsl"
class CutlassDSLSetupError(Exception):
    """Signals that a step of the CuTeDSL editable-install setup failed."""
def download_wheel(temp_dir: Path) -> Path:
    """
    Download the nvidia-cutlass-dsl wheel to a temporary directory.

    Invokes ``pip download --no-deps`` through the current interpreter
    (``sys.executable``) and then looks for the resulting ``.whl`` file in
    ``temp_dir``.

    Args:
        temp_dir: Temporary directory path for downloading

    Returns:
        Path to the downloaded wheel file

    Raises:
        CutlassDSLSetupError: If download fails or wheel not found
    """
    logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}")
    command = [
        sys.executable,
        "-m",
        "pip",
        "download",
        "--no-deps",
        PACKAGE_NAME,
        "--dest",
        str(temp_dir),
    ]
    try:
        # subprocess.run() drains both pipes, so pip cannot block on a full
        # pipe buffer and the captured output is attached to the raised
        # CalledProcessError. check_call() with PIPE does neither.
        subprocess.run(
            command,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        error_msg = f"Failed to download {PACKAGE_NAME}: {e}"
        # errors="replace": a diagnostic message must not itself fail on
        # output that is not valid UTF-8.
        if e.stdout:
            stdout_text = e.stdout.decode(errors="replace")
            error_msg += f"\nstdout: {stdout_text}"
        if e.stderr:
            stderr_text = e.stderr.decode(errors="replace")
            error_msg += f"\nstderr: {stderr_text}"
        raise CutlassDSLSetupError(error_msg) from e

    # Find the downloaded wheel file. Wheel filenames use underscores where
    # the distribution name has dashes.
    wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl"
    # sorted() makes the choice deterministic should more than one match.
    wheel_files = sorted(temp_dir.glob(wheel_pattern))
    if not wheel_files:
        raise CutlassDSLSetupError(
            f"No wheel file matching {wheel_pattern} found after download"
        )
    wheel_path = wheel_files[0]
    logger.info(f"Successfully downloaded: {wheel_path.name}")
    return wheel_path
def extract_version_from_wheel(wheel_path: Path) -> str:
    """
    Derive the editable-install dev version from a wheel's filename.

    Args:
        wheel_path: Path to the wheel file

    Returns:
        ``'{version}.dev0'`` when the wheel version carries no dev segment,
        otherwise ``'{base_version}.dev{n+1}'`` for a wheel version ending in
        ``.dev{n}``

    Raises:
        CutlassDSLSetupError: If version cannot be extracted from filename
    """
    wheel_filename = wheel_path.name

    # Wheel filename format:
    # {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl
    normalized_name = PACKAGE_NAME.replace("-", "_")
    name_and_version = re.compile(rf"{re.escape(normalized_name)}-([^-]+)-")
    found = name_and_version.match(wheel_filename)
    if found is None:
        raise CutlassDSLSetupError(
            f"Could not parse version from wheel filename: {wheel_filename}"
        )

    version = found.group(1)

    # Plain release: simply tag it as the first dev build.
    dev_suffix = re.match(r"^(.+)\.dev(\d+)", version)
    if dev_suffix is None:
        dev_version = f"{version}.dev0"
        logger.info(f"Detected version: {version} -> using {dev_version}")
        return dev_version

    # Already a dev build: bump the dev counter by one.
    base_version, dev_digits = dev_suffix.groups()
    dev_number = int(dev_digits)
    dev_version = f"{base_version}.dev{dev_number + 1}"
    logger.info(
        f"Detected version with dev{dev_number}: {version} -> using {dev_version}"
    )
    return dev_version
def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None:
    """
    Extract wheel contents to specified directory.

    Args:
        wheel_path: Path to the wheel file
        extract_dir: Directory to extract contents to

    Raises:
        CutlassDSLSetupError: If extraction fails; the original exception is
            chained as the cause
    """
    logger.info(f"Extracting wheel contents to {extract_dir}")
    try:
        with zipfile.ZipFile(wheel_path, "r") as wheel_zip:
            wheel_zip.extractall(extract_dir)
    except zipfile.BadZipFile as e:
        raise CutlassDSLSetupError(f"Invalid wheel file {wheel_path}: {e}") from e
    except Exception as e:
        # Deliberately broad: every extraction failure is reported to callers
        # as the single setup error type, with the root cause chained.
        raise CutlassDSLSetupError(f"Failed to extract wheel: {e}") from e
    logger.info("Wheel extraction completed successfully")
def copy_library_files(extract_dir: Path, package_root: Path) -> int:
    """
    Copy .so library files from extracted wheel to package lib directory.

    Every ``*.so`` file located directly inside any ``lib`` directory under
    ``extract_dir`` is copied (flattened by filename) into
    ``package_root / "lib"``, overwriting files of the same name.

    Args:
        extract_dir: Directory containing extracted wheel contents
        package_root: Root directory of the package

    Returns:
        Number of files copied
    """
    # sorted() fixes the copy order, so that when two lib/ directories hold a
    # file of the same name the winner no longer depends on filesystem
    # traversal order. Non-files matching the pattern are skipped.
    so_files = sorted(
        path for path in extract_dir.rglob("lib/*.so") if path.is_file()
    )
    if not so_files:
        logger.warning("No .so files found in the wheel")
        return 0
    logger.info(f"Found {len(so_files)} .so files")

    # Create lib directory
    lib_dir = package_root / "lib"
    lib_dir.mkdir(exist_ok=True)

    # Copy .so files
    for so_file in so_files:
        dest_path = lib_dir / so_file.name
        logger.info(f"Copying {so_file.name} to {dest_path}")
        shutil.copy2(so_file, dest_path)

    copied_count = len(so_files)
    logger.info(f"Successfully copied {copied_count} .so files to lib/")
    return copied_count
def copy_python_packages(extract_dir: Path, package_root: Path) -> Tuple[int, int]:
    """
    Merge the wheel's python_packages/cutlass/ tree into the local cutlass/ tree.

    Files that already exist at the destination are left untouched and
    counted as skipped; only missing files are copied.

    Args:
        extract_dir: Directory containing extracted wheel contents
        package_root: Root directory of the package

    Returns:
        Tuple of (files_copied, files_skipped)
    """
    # Locate the source tree; the first match is used.
    candidates = list(extract_dir.rglob("python_packages/cutlass"))
    if not candidates:
        logger.warning("No python_packages/cutlass/ directory found in the wheel")
        return 0, 0

    cutlass_source_dir = candidates[0]
    cutlass_dest_dir = package_root / "cutlass"
    logger.info("Found python_packages/cutlass/ directory")
    logger.info(f"Copying from {cutlass_source_dir} to {cutlass_dest_dir}")

    copied_count = skipped_count = 0
    for entry in cutlass_source_dir.rglob("*"):
        if not entry.is_file():
            continue

        # Mirror the file's position relative to the source root.
        rel_path = entry.relative_to(cutlass_source_dir)
        target = cutlass_dest_dir / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)

        if not target.exists():
            shutil.copy2(entry, target)
            copied_count += 1
            logger.info(f" Copied {rel_path}")
        else:
            # Existing local files take precedence over the wheel's copy.
            skipped_count += 1
            logger.debug(f" Skipping {rel_path} (already exists)")

    logger.info(
        f"Cutlass directory update: {copied_count} files copied, {skipped_count} files skipped"
    )
    return copied_count, skipped_count
def write_version_file(version: str, package_root: Path) -> None:
    """
    Write version string to the VERSION.EDITABLE file in the package root.

    The file is (re)created as UTF-8 text holding the version followed by a
    single newline.

    Args:
        version: Version string to write
        package_root: Root directory of the package

    Raises:
        CutlassDSLSetupError: If the file cannot be written; the original
            exception is chained as the cause
    """
    version_file = package_root / "VERSION.EDITABLE"
    logger.info(f"Writing version {version} to {version_file}")
    try:
        with open(version_file, "w", encoding="utf-8") as f:
            f.write(version + "\n")
    except Exception as e:
        # Broad on purpose so callers only ever see the setup error type.
        raise CutlassDSLSetupError(f"Failed to write VERSION file: {e}") from e
    logger.info(f"Successfully created VERSION file with version: {version}")
def prep_editable_install() -> None:
    """
    Prepare the CuTeDSL source tree for an editable install.

    Downloads the nvidia-cutlass-dsl wheel into a temporary directory, derives
    the dev version from its filename, unpacks it, copies the shared libraries
    and missing Python files next to this script, and records the version.

    Raises:
        CutlassDSLSetupError: If setup fails
    """
    package_root = Path(__file__).parent

    with tempfile.TemporaryDirectory() as scratch:
        temp_dir = Path(scratch)

        # Fetch the wheel and work out the version before unpacking it.
        wheel_path = download_wheel(temp_dir)
        version = extract_version_from_wheel(wheel_path)

        extract_dir = temp_dir / "extracted"
        extract_wheel_contents(wheel_path, extract_dir)

        # Populate lib/ and cutlass/ from the unpacked wheel.
        lib_files_copied = copy_library_files(extract_dir, package_root)
        copy_result = copy_python_packages(extract_dir, package_root)
        py_files_copied, py_files_skipped = copy_result

        # Record the derived version alongside the sources.
        write_version_file(version, package_root)

        summary = (
            f"Summary: {lib_files_copied} lib files, "
            f"{py_files_copied} Python files copied, "
            f"{py_files_skipped} Python files skipped"
        )
        logger.info("Setup completed successfully!")
        logger.info(summary)
        logger.info(f"Detected upstream version: {version}")
# Script entry point: run the setup only when executed directly, not on import.
# A CutlassDSLSetupError raised by the setup is not caught here and propagates.
if __name__ == "__main__":
    prep_editable_install()