3.6.0 update (#2005)

* 3.6.0 update

* doc and swap stuff

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: Yujia Zhai
Date:   2024-12-24 22:34:40 -08:00
Committed by: GitHub
Commit: 3d261a5974 (parent e1cd8c7866)

258 changed files with 10863 additions and 3883 deletions


@@ -57,6 +57,19 @@ CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path)
 # Alias CUTLASS_PATH as source_path
 source_path = CUTLASS_PATH
+
+_NVCC_VERSION = None
+def nvcc_version():
+    global _NVCC_VERSION
+    if _NVCC_VERSION is None:
+        import subprocess
+        # Attempt to get NVCC version
+        result = subprocess.run(['nvcc', '--version'], capture_output=True)
+        if result.returncode != 0:
+            raise Exception('Unable to run `nvcc --version`')
+        _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0]
+    return _NVCC_VERSION
+
 _CUDA_INSTALL_PATH = None
 def cuda_install_path():
     """

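For reference, the release-string slicing above can be exercised standalone. A minimal sketch with a sample `nvcc --version` banner (the banner text here is illustrative, not captured output):

# Hypothetical sample of an `nvcc --version` banner; real output varies by toolkit.
sample_stdout = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.4, V12.4.131\n"
)

# Same slicing as the diff: take the text after " release " up to the comma.
version = sample_stdout.split(" release ")[-1].split(",")[0]
print(version)  # -> 12.4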

@@ -139,7 +139,7 @@ def get_tile_scheduler_arguments_3x(
     splits: int = 1):
     max_swizzle_size = 1
     raster_order_option = 0 # Heuristic
-    if tile_scheduler == TileSchedulerType.Persistent:
+    if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]:
         return _PersistentTileSchedulerArguments(
             max_swizzle_size,
             raster_order_option,

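The net effect of this hunk: TileSchedulerType.Default now takes the same persistent-scheduler argument path as TileSchedulerType.Persistent. A hedged sketch of just that dispatch, with enum names mirrored from the diff and everything else assumed:

from enum import Enum, auto

class TileSchedulerType(Enum):  # names mirrored from the diff
    Default = auto()
    Persistent = auto()
    StreamK = auto()

def uses_persistent_arguments(tile_scheduler):
    # After this commit, Default is routed to the persistent-scheduler arguments.
    return tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]

assert uses_persistent_arguments(TileSchedulerType.Default)
assert not uses_persistent_arguments(TileSchedulerType.StreamK)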

@@ -90,7 +90,7 @@ class CompilationOptions:
opts.append(f"--include-path={incl}")
arch_flag = f"-arch=sm_{self.arch}"
if self.arch == 90:
if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12:
arch_flag += "a"
opts.append(arch_flag)
@@ -237,7 +237,7 @@ class ArtifactManager:
             if incl not in includes:
                 includes.append(incl)
 
-        includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
+        includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes
         for incl in includes:
             source_buffer_device += SubstituteTemplate(
                 IncludeTemplate,

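Both hunks in this file adjust host compilation: sm_90a is now emitted only under nvcc 12+, and the host include list switches from stddef.h to the C++ header cstddef. A minimal sketch of the arch-flag gating, with the nvcc version passed in rather than queried from the toolchain:

def arch_flag(arch, nvcc_version):
    # sm_90a exposes Hopper-specific instructions; nvcc older than 12 does not
    # accept the "a" suffix, so it is only appended for nvcc 12+.
    flag = f"-arch=sm_{arch}"
    if arch == 90 and int(nvcc_version.split('.')[0]) >= 12:
        flag += "a"
    return flag

assert arch_flag(90, "12.4") == "-arch=sm_90a"
assert arch_flag(90, "11.8") == "-arch=sm_90"
assert arch_flag(80, "12.4") == "-arch=sm_80"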

@@ -44,6 +44,7 @@ from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torc
 dtype2ctype = {
     DataType.f16: ctypes.c_uint16,
     DataType.bf16: ctypes.c_uint16,
+    DataType.f32: ctypes.c_float,
     DataType.f64: ctypes.c_double,
     DataType.s8: ctypes.c_int8,

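Context for the mapping: ctypes has no half-precision type, so f16 and bf16 cross the boundary as raw 16-bit patterns (c_uint16), while the newly added f32 maps natively to c_float. A hedged round-trip sketch using numpy to produce the bit pattern:

import ctypes
import numpy as np

# A half-precision value crosses the ctypes boundary as its raw 16-bit pattern.
half = np.float16(1.5)
raw = ctypes.c_uint16(half.view(np.uint16).item())
assert np.uint16(raw.value).view(np.float16) == np.float16(1.5)

# f32 needs no such trick: ctypes models it natively.
assert ctypes.c_float(1.5).value == 1.5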

@@ -59,18 +59,21 @@ def max(x, dim):
     elif is_torch_tensor(x):
         return torch.amax(x, dim)
 
+def maximum(x, y):
+    if is_numpy_tensor(x):
+        return np.maximum(x, y)
+    elif is_torch_tensor(x):
+        return torch.maximum(x, torch.tensor(y))
+
+def minimum(x, y):
+    if is_numpy_tensor(x):
+        return np.minimum(x, y)
+    elif is_torch_tensor(x):
+        return torch.minimum(x, torch.tensor(y))
+
 ##############################################################################
 # Layout manipulate nodes
 ##############################################################################

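Usage note: these helpers dispatch on the tensor library, and the torch branch wraps the scalar operand in a tensor because torch.maximum/torch.minimum require two tensors. A numpy-only usage sketch (the torch branch behaves the same way):

import numpy as np

x = np.array([-1.0, 0.5, 2.0])
# maximum(x, 0.0) with a numpy input reduces to np.maximum, i.e. a ReLU-style clamp.
print(np.maximum(x, 0.0))   # [0.  0.5 2. ]
print(np.minimum(x, 1.0))   # [-1.   0.5  1. ]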

@@ -51,6 +51,20 @@ _generator_ccs = [50, 60, 61, 70, 75, 80, 90]
 # Strip any additional information from the CUDA version
 _cuda_version = __version__.split("rc")[0]
+
+# Check that the Python CUDA version is at least the NVCC version
+_nvcc_version = cutlass.nvcc_version()
+_cuda_list = _cuda_version.split('.')
+_nvcc_list = _nvcc_version.split('.')
+for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
+    if int(val_cuda) < int(val_nvcc):
+        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
+
+if len(_nvcc_list) > len(_cuda_list):
+    if len(_nvcc_list) != len(_cuda_list) + 1:
+        raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
+    if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
+        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
 
 class KernelsForDataType:
     """
@@ -278,7 +292,7 @@ class ArchOptions:
             ]
             manifest_args = cutlass_library.generator.define_parser().parse_args(args)
             manifest = cutlass_library.manifest.Manifest(manifest_args)
-            generate_function(manifest, _cuda_version)
+            generate_function(manifest, _nvcc_version)
 
             if operation_kind not in manifest.operations:
                 # No kernels generated for this architecture, this could be because the CUDA

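A hedged reconstruction of the comparison above as a standalone predicate; the names are mine, the committed code raises instead of returning, and this sketch also stops at the first field that compares greater:

def cuda_at_least_nvcc(cuda_version, nvcc_version):
    cuda_list = cuda_version.split('.')
    nvcc_list = nvcc_version.split('.')
    for val_cuda, val_nvcc in zip(cuda_list, nvcc_list):
        if int(val_cuda) != int(val_nvcc):
            return int(val_cuda) > int(val_nvcc)
    # nvcc may carry one extra field (e.g. "12.4.1" vs "12.4"); a nonzero
    # extra field means nvcc is newer than the Python CUDA package.
    if len(nvcc_list) == len(cuda_list) + 1:
        return int(nvcc_list[-1]) == 0
    return True

assert cuda_at_least_nvcc("12.4", "12.4")
assert cuda_at_least_nvcc("12.4", "12.4.0")
assert not cuda_at_least_nvcc("12.4", "12.4.1")
assert not cuda_at_least_nvcc("12.1", "12.4")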

@@ -818,6 +818,7 @@ ${compile_guard_end}
     element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
     element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
     epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
+    is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
     values = {
       'operation_name': operation.procedural_name(),
       'operation_suffix': self.operation_suffix,


@@ -177,7 +177,7 @@ def CreateGemmUniversal3xOperator(
     complex_transforms=None,
     epilogue_functor=EpilogueFunctor.LinearCombination,
     swizzling_functor=SwizzlingFunctor.Identity1,
-    tile_schedulers=[TileSchedulerType.Persistent]):
+    tile_schedulers=[TileSchedulerType.Default]):
 
   if type(data_types) is dict:
     data_types = [data_types]
@@ -226,7 +226,7 @@ def CreateSparseGemmUniversal3xOperator(
     complex_transforms=None,
     epilogue_functor=EpilogueFunctor.LinearCombination,
     swizzling_functor=SwizzlingFunctor.Identity1,
-    tile_schedulers=[TileSchedulerType.Persistent]):
+    tile_schedulers=[TileSchedulerType.Default]):
 
   if type(data_types) is dict:
     data_types = [data_types]
@@ -1048,7 +1048,7 @@ def CreateConvOperator3x(manifest: Manifest,
     schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \
       [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)],
     complex_transforms: Optional[Sequence[ComplexTransform]] = None,
-    tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Persistent],
+    tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default],
     conv_kind: ConvKind = ConvKind.Fprop,
     log_indent_level: int = 1):
   """
@@ -6508,6 +6508,7 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version):
       data_type, alignment_constraints, BlasMode.hermitian)
 
 #
 ###################################################################################################
+
 def GenerateSM90_Conv3x(manifest, cuda_version,
@@ -6703,6 +6704,7 @@ def GenerateSM90_Conv3x(manifest, cuda_version,
       product(
         (
           ConvKind.Dgrad,
+          ConvKind.Wgrad
         ),
         spatial_dims,
         (

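The Wgrad addition means every spatial-dims/config combination in this product() is now stamped out for both Dgrad and Wgrad. An illustrative expansion with placeholder values:

from itertools import product

conv_kinds = ("Dgrad", "Wgrad")      # mirrors the tuple in the diff
spatial_dims = (2, 3)                # placeholder values for illustration

print(list(product(conv_kinds, spatial_dims)))
# [('Dgrad', 2), ('Dgrad', 3), ('Wgrad', 2), ('Wgrad', 3)]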

@@ -75,6 +75,7 @@ class DataType(enum.Enum):
   u16 = enum_auto()
   u32 = enum_auto()
   u64 = enum_auto()
+  s2 = enum_auto()
   s4 = enum_auto()
   s8 = enum_auto()
   s16 = enum_auto()
@@ -92,11 +93,13 @@
   cf32 = enum_auto()
   ctf32 = enum_auto()
   cf64 = enum_auto()
+  cs2 = enum_auto()
   cs4 = enum_auto()
   cs8 = enum_auto()
   cs16 = enum_auto()
   cs32 = enum_auto()
   cs64 = enum_auto()
+  cu2 = enum_auto()
   cu4 = enum_auto()
   cu8 = enum_auto()
   cu16 = enum_auto()
@@ -126,6 +129,7 @@ DataTypeNames = {
   DataType.u16: "u16",
   DataType.u32: "u32",
   DataType.u64: "u64",
+  DataType.s2: "s2",
   DataType.s4: "s4",
   DataType.s8: "s8",
   DataType.s16: "s16",
@@ -143,11 +147,13 @@ DataTypeNames = {
   DataType.cf32: "cf32",
   DataType.ctf32: "ctf32",
   DataType.cf64: "cf64",
+  DataType.cu2: "cu2",
   DataType.cu4: "cu4",
   DataType.cu8: "cu8",
   DataType.cu16: "cu16",
   DataType.cu32: "cu32",
   DataType.cu64: "cu64",
+  DataType.cs2: "cs2",
   DataType.cs4: "cs4",
   DataType.cs8: "cs8",
   DataType.cs16: "cs16",
@@ -164,6 +170,7 @@ DataTypeTag = {
   DataType.u16: "uint16_t",
   DataType.u32: "uint32_t",
   DataType.u64: "uint64_t",
+  DataType.s2: "cutlass::int2b_t",
   DataType.s4: "cutlass::int4b_t",
   DataType.s8: "int8_t",
   DataType.s16: "int16_t",
@@ -181,11 +188,13 @@ DataTypeTag = {
   DataType.cf32: "cutlass::complex<float>",
   DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
   DataType.cf64: "cutlass::complex<double>",
+  DataType.cu2: "cutlass::complex<cutlass::uint2b_t>",
   DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
   DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
   DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
   DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
   DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
+  DataType.cs2: "cutlass::complex<cutlass::int2b_t>",
   DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
   DataType.cs8: "cutlass::complex<cutlass::int8_t>",
   DataType.cs16: "cutlass::complex<cutlass::int16_t>",
@@ -202,6 +211,7 @@ DataTypeSize = {
   DataType.u16: 16,
   DataType.u32: 32,
   DataType.u64: 64,
+  DataType.s2: 2,
   DataType.s4: 4,
   DataType.s8: 8,
   DataType.s16: 16,
@@ -219,11 +229,13 @@ DataTypeSize = {
   DataType.cf32: 64,
   DataType.ctf32: 32,
   DataType.cf64: 128,
+  DataType.cu2: 4,
   DataType.cu4: 8,
   DataType.cu8: 16,
   DataType.cu16: 32,
   DataType.cu32: 64,
   DataType.cu64: 128,
+  DataType.cs2: 4,
   DataType.cs4: 8,
   DataType.cs8: 16,
   DataType.cs16: 32,

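A quick sanity check on the new DataTypeSize entries: a complex type stores two scalar components, so each c-prefixed width is exactly twice its scalar counterpart. Sketch over the values added above:

scalar_bits = {"s2": 2, "s4": 4, "u2": 2, "u4": 4}
complex_bits = {"cs2": 4, "cs4": 8, "cu2": 4, "cu4": 8}

for name, bits in complex_bits.items():
    # "cs2" -> scalar "s2", etc.
    assert bits == 2 * scalar_bits[name[1:]], name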

@@ -492,6 +492,21 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     if not (is_fp8 and is_sparse):
       schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue])
     stream_k_schedules = []
+
+  if CudaToolkitVersionSatisfies(cuda_version, 12, 0):
+    if can_do_tma_epilogue:
+      assert not requires_transposed_epilogue
+      # Inconsistency: fp8 pingpong only gets stamped out with fast accum
+      if not is_fp8 or level >= 1:
+        schedules.append([
+          KernelScheduleType.TmaWarpSpecializedPingpong,
+          EpilogueScheduleType.TmaWarpSpecialized
+        ])
+      if can_do_fp8_fast_accum:
+        schedules.append([
+          KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
+          EpilogueScheduleType.TmaWarpSpecialized
+        ])
 
   if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
     # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
@@ -526,17 +541,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
     # persistent kernels with TMA epilogues
     if can_do_tma_epilogue:
       assert not requires_transposed_epilogue
-      # Inconsistency: fp8 pingpong only gets stamped out with fast accum
-      if not is_fp8 or level >= 1:
-        schedules.append([
-          KernelScheduleType.TmaWarpSpecializedPingpong,
-          EpilogueScheduleType.TmaWarpSpecialized
-        ])
-      if can_do_fp8_fast_accum:
-        schedules.append([
-          KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
-          EpilogueScheduleType.TmaWarpSpecialized
-        ])
       if can_do_cooperative:
         # Sparse kernels only support FastAccum FP8 mainloop
         if not (is_fp8 and is_sparse):
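
Net effect of these two hunks: the pingpong TMA-epilogue schedules moved from the CUDA 12.1 block into the 12.0 block. A hedged sketch of the version predicate as used at these call sites (the real helper lives in the generator scripts; this signature is assumed from the usage above):

def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor):
    # Compare only the major/minor fields, as the call sites above do.
    parts = [int(p) for p in semantic_ver_string.split('.')[:2]]
    return tuple(parts) >= (major, minor)

assert CudaToolkitVersionSatisfies("12.0", 12, 0)      # pingpong now emitted here
assert not CudaToolkitVersionSatisfies("11.8", 12, 0)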