mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-29 19:07:07 +00:00
* v4.6 dev update. * Remove CUTLASS_HOST_DEVICE from CudaHostAdapater::memsetDevice (#3286) * [SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM (#3280) * gemm: add SM120 array TMA collective for tensor/token-scaled FP8 grouped GEMM Adds CollectiveMma and CollectiveBuilder specializations for MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM (MoE expert dispatch) with tensor- and token-level FP8 scaling on SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10). New files: - include/cutlass/gemm/collective/sm120_mma_array_tma.hpp CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized. Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules. Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B. Supports F8F6F4 MMA with TMA loads for both A and B operands. - include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized dispatch policy, produces correctly-typed CollectiveOp. Modified files: - collective_mma.hpp: include sm120_mma_array_tma.hpp - collective_builder.hpp: include sm120_array_mma_builder.inl - sm120_mma_builder.inl: remove ptr-array schedules from enable_if (they now route to sm120_array_mma_builder.inl) and drop the IsPtrArrayKernel static_assert that enforced the restriction Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B total / 4B active). Previously fell back to a non-CUTLASS Triton path; with this patch, the SM120 CUTLASS grouped GEMM collective activates and produces correct outputs. Short-sequence throughput improved ~7% vs the fallback baseline (76.3 → 81.9 tok/s). 
Closes #3263 Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> * test: add SM120 ptr-array grouped GEMM unit tests Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder specializations introduced for MainloopSm120ArrayTmaWarpSpecialized, covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs, and two tile shapes. Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per reviewer request in PR #3280. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> --------- Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> Co-authored-by: Alex Georgiev <89279829+alexngUNC@users.noreply.github.com> Co-authored-by: Tyler <tgmerritt@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
214 lines
7.7 KiB
Python
214 lines
7.7 KiB
Python
#################################################################################################
|
|
#
|
|
# Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
#################################################################################################
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
import cutlass_library
|
|
|
|
|
|
def _cuda_install_path_from_nvcc() -> str:
    """
    Infer the CUDA installation directory from the location of ``nvcc``.

    ``nvcc`` is searched for on ``PATH``. The installation directory is taken to be
    everything in the resolved path that precedes ``/bin/nvcc``.

    :return: directory that contains ``bin/nvcc``
    :rtype: str

    :raises Exception: if ``nvcc`` cannot be found on ``PATH`` or the derived directory does not exist
    """
    import shutil

    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC.
    # shutil.which performs the PATH search in-process, so this no longer depends on
    # a `which` executable being installed at exactly /usr/bin/which.
    nvcc_path = shutil.which('nvcc')
    if nvcc_path is None:
        raise Exception('Unable to find nvcc via `which` utility.')

    cuda_install_path = nvcc_path.split('/bin/nvcc')[0]
    if not os.path.isdir(cuda_install_path):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, '
                        f'and default path of {cuda_install_path} does not exist.')

    return cuda_install_path
|
|
|
|
|
|
# Location of the CUTLASS sources: the CUTLASS_PATH environment variable takes
# precedence, otherwise the path reported by cutlass_library is used.
CUTLASS_PATH = os.environ.get("CUTLASS_PATH", cutlass_library.source_path)

# Alias CUTLASS_PATH as source_path
source_path = CUTLASS_PATH
|
|
|
|
# Cached result of nvcc_version(); None until the first successful query
_NVCC_VERSION = None


def nvcc_version():
    """
    Return the release version reported by ``nvcc --version`` (for example ``"12.4"``).

    The value is the text between the last ``" release "`` marker and the following
    comma in the command's output. It is cached in ``_NVCC_VERSION`` so that ``nvcc``
    is executed at most once per process.

    :return: NVCC release version
    :rtype: str

    :raises Exception: if ``nvcc --version`` exits with a non-zero status or its output
        contains no release marker
    """
    global _NVCC_VERSION
    if _NVCC_VERSION is None:
        import subprocess

        # Attempt to get NVCC version
        result = subprocess.run(['nvcc', '--version'], capture_output=True)
        if result.returncode != 0:
            raise Exception('Unable to run `nvcc --version`')

        # Parse the decoded text rather than the repr of the captured bytes, so that
        # escape sequences and quoting from the bytes repr can never leak into the result.
        output = result.stdout.decode('utf-8', errors='replace')
        _, marker, tail = output.rpartition(" release ")
        if not marker:
            raise Exception('Unable to parse NVCC version from `nvcc --version` output')
        _NVCC_VERSION = tail.split(",")[0].strip()
    return _NVCC_VERSION
|
|
|
|
# Cached result of cuda_install_path(); None until the first successful lookup
_CUDA_INSTALL_PATH = None


def cuda_install_path():
    """
    Helper method for on-demand fetching of the CUDA installation path. This allows
    the import of CUTLASS to proceed even if NVCC is not available, preferring to
    raise this error only when an operation that needs NVCC is being performed.

    The ``CUDA_INSTALL_PATH`` environment variable is used when it is defined; only
    when it is not defined is the path derived from the location of ``nvcc``.

    :return: CUDA installation path
    :rtype: str
    """
    global _CUDA_INSTALL_PATH
    if _CUDA_INSTALL_PATH is None:
        # The fallback must not be passed as os.getenv's default: a default argument is
        # evaluated before the call, which would probe for nvcc (and raise when it is
        # missing) even though CUDA_INSTALL_PATH is defined.
        install_path = os.getenv("CUDA_INSTALL_PATH")
        if install_path is None:
            install_path = _cuda_install_path_from_nvcc()
        _CUDA_INSTALL_PATH = install_path
    return _CUDA_INSTALL_PATH
|
|
|
|
# File name of the compilation cache.
# NOTE(review): presumably a database of compiled kernels — confirm against the code that opens it.
CACHE_FILE = "compiled_cache.db"
|
|
|
|
from cutlass_library import (
|
|
DataType,
|
|
EpilogueScheduleType,
|
|
KernelScheduleType,
|
|
MathOperation,
|
|
LayoutType,
|
|
OpcodeClass,
|
|
TileDescription,
|
|
TileSchedulerType,
|
|
)
|
|
|
|
# Handle to this module object; package-wide state is stored on it as attributes
this = sys.modules[__name__]
this.logger = logging.getLogger(__name__)

# RMM is only supported for Python 3.9+
this.use_rmm = False
if sys.version_info >= (3, 9):
    try:
        import rmm
    except ImportError:
        # RMM is optional; keep use_rmm disabled when it is not installed
        pass
    else:
        this.use_rmm = True
|
|
|
|
|
|
def set_log_level(level: int):
    """
    Set the severity threshold of this package's logger.

    :param level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options
    :type level: int
    """
    package_logger = this.logger
    package_logger.setLevel(level)
|
|
|
|
# Default verbosity: the package logger is set to the ERROR level at import time
set_log_level(logging.ERROR)
|
|
|
|
from cutlass_cppgen.library_defaults import OptionRegistry
|
|
from cutlass_cppgen.backend.utils.device import device_cc
|
|
|
|
# Lazily created OptionRegistry; populated on the first get_option_registry() call
this._option_registry = None


def get_option_registry():
    """
    Return the shared option registry, constructing it on first use.

    Deferring construction keeps the registry from being built merely because
    CUTLASS was imported.
    """
    registry = this._option_registry
    if registry is not None:
        return registry

    this.logger.info("Initializing option registry")
    registry = OptionRegistry(device_cc())
    this._option_registry = registry
    return registry
|
|
|
|
# Version string of this package, published as the module attribute __version__
this.__version__ = '4.6.0'
|
|
|
|
from cutlass_cppgen.backend import create_memory_pool
|
|
from cutlass_cppgen.emit.pytorch import pytorch
|
|
from cutlass_cppgen.op.gemm import Gemm
|
|
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
|
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
|
from cutlass_cppgen.op.op import OperationBase
|
|
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
|
from cutlass_cppgen.utils.lazy_import import lazy_import
|
|
|
|
|
|
# Lazily created memory pool; stays None when RMM is not in use
this.memory_pool = None


def get_memory_pool():
    """
    Return the shared memory pool, creating it on first use when RMM is enabled.

    Creating the pool on demand avoids allocating it unnecessarily when CUTLASS
    is imported. When RMM is not in use no pool is created and the current value
    of ``this.memory_pool`` is returned unchanged.
    """
    if not this.use_rmm or this.memory_pool is not None:
        return this.memory_pool

    # NOTE(review): pool sizes are presumably in bytes (1 GiB initial, 4 GiB maximum) — confirm
    this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32)
    return this.memory_pool
|
|
|
|
|
|
base_cuda = lazy_import("cuda")
|
|
cuda = lazy_import("cuda.cuda")
|
|
cudart = lazy_import("cuda.cudart")
|
|
|
|
# Selected CUDA device ID; None until initialize_cuda_context() assigns it
this._device_id = None
|
|
# NVCC version string recorded by check_cuda_versions(); None until that check has run
this._nvcc_version = None
|
|
|
|
def check_cuda_versions():
    """
    Verify that the Python CUDA package is at least as new as NVCC.

    The versions are compared numerically, most-significant component first: the
    first component that differs decides the outcome. The NVCC version is also
    recorded in ``this._nvcc_version``.

    :raises Exception: if the Python CUDA version is older than the NVCC version, or if
        the NVCC version has more than one component beyond those of the Python CUDA version
    """
    # Strip any additional information from the CUDA version
    _cuda_version = base_cuda.__version__.split("rc")[0]
    # Check that Python CUDA version exceeds NVCC version
    this._nvcc_version = nvcc_version()
    _cuda_list = _cuda_version.split('.')
    _nvcc_list = this._nvcc_version.split('.')

    # Compare lexicographically. A strictly larger earlier component settles the
    # comparison, so later components must not be examined (13.0 is newer than 12.8
    # even though 0 < 8).
    cuda_is_newer = False
    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
        cuda_part, nvcc_part = int(val_cuda), int(val_nvcc)
        if cuda_part > nvcc_part:
            cuda_is_newer = True
            break
        if cuda_part < nvcc_part:
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")

    if len(_nvcc_list) > len(_cuda_list):
        if len(_nvcc_list) != len(_cuda_list) + 1:
            raise Exception(f"Malformatted NVCC version of {this._nvcc_version}")
        # Every shared component is equal, so a non-zero extra NVCC component makes NVCC newer
        if not cuda_is_newer and int(_nvcc_list[-1]) != 0:
            raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}")
|
|
|
|
def initialize_cuda_context():
    """
    Check the CUDA versions and, on first use, choose the CUDA device ID.

    The device ID is read from the ``CUTLASS_CUDA_DEVICE_ID`` environment variable
    when it is set. Otherwise device 0 is chosen after confirming that at least one
    CUDA device is present. The result is stored in ``this._device_id``; later calls
    return right after the version check.

    :raises RuntimeError: if the ``cudaFree`` call used to create the context fails
    :raises Exception: if the device count cannot be queried or no CUDA device is found
    """
    check_cuda_versions()

    # A device has already been selected; nothing more to do
    if this._device_id is not None:
        return

    rmm_enabled = this.use_rmm
    if rmm_enabled:
        # This also covers initializing the CUDA context
        get_memory_pool()

    requested = os.environ.get("CUTLASS_CUDA_DEVICE_ID")
    if requested is not None:
        this._device_id = int(requested)
        return

    if not rmm_enabled:
        # Manually call cuInit() and create context by making a runtime API call
        err, = cudart.cudaFree(0)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"cudaFree failed with error {err}")

    err, device_count = cuda.cuDeviceGetCount()
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f"cuDeviceGetCount failed with error {err}")
    if device_count <= 0:
        raise Exception("No CUDA devices found")

    # No explicit request: fall back to the first device
    this._device_id = 0
|
|
|
|
|
|
def device_id() -> int:
    """
    Return the CUDA device ID in use, running CUDA context initialization first.

    :return: value of ``this._device_id`` after ``initialize_cuda_context()`` has run
    :rtype: int
    """
    initialize_cuda_context()
    selected_device = this._device_id
    return selected_device
|