mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
3.6.0 update (#2005)
* 3.6.0 update * doc and swap stuff --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
@@ -57,6 +57,19 @@ CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path)
|
||||
# Alias CUTLASS_PATH as source_path
|
||||
source_path = CUTLASS_PATH
|
||||
|
||||
_NVCC_VERSION = None


def nvcc_version():
    """
    Return the CUDA release reported by ``nvcc --version`` (for example, ``"12.4"``).

    The value is cached in the module-level ``_NVCC_VERSION`` so that NVCC is launched
    at most once per process; later calls return the cached string.

    :return: the text between ``" release "`` and the next comma in NVCC's output
    :rtype: str

    :raises Exception: if ``nvcc --version`` exits with a non-zero status, or if its
        output contains no ``" release "`` marker to parse
    """
    global _NVCC_VERSION
    if _NVCC_VERSION is None:
        import subprocess

        # Attempt to get NVCC version
        result = subprocess.run(['nvcc', '--version'], capture_output=True)
        if result.returncode != 0:
            raise Exception('Unable to run `nvcc --version`')

        # Parse the decoded text instead of the repr of the captured bytes
        output = result.stdout.decode(errors="replace")
        _, marker, tail = output.rpartition(" release ")
        if not marker:
            # Fail here with a clear message rather than caching an unparsable string
            raise Exception('Unable to parse the release from `nvcc --version` output')
        _NVCC_VERSION = tail.split(",")[0].strip()
    return _NVCC_VERSION
|
||||
|
||||
_CUDA_INSTALL_PATH = None
|
||||
def cuda_install_path():
|
||||
"""
|
||||
|
||||
@@ -139,7 +139,7 @@ def get_tile_scheduler_arguments_3x(
|
||||
splits: int = 1):
|
||||
max_swizzle_size = 1
|
||||
raster_order_option = 0 # Heuristic
|
||||
if tile_scheduler == TileSchedulerType.Persistent:
|
||||
if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]:
|
||||
return _PersistentTileSchedulerArguments(
|
||||
max_swizzle_size,
|
||||
raster_order_option,
|
||||
|
||||
@@ -90,7 +90,7 @@ class CompilationOptions:
|
||||
opts.append(f"--include-path={incl}")
|
||||
|
||||
arch_flag = f"-arch=sm_{self.arch}"
|
||||
if self.arch == 90:
|
||||
if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12:
|
||||
arch_flag += "a"
|
||||
opts.append(arch_flag)
|
||||
|
||||
@@ -237,7 +237,7 @@ class ArtifactManager:
|
||||
if incl not in includes:
|
||||
includes.append(incl)
|
||||
|
||||
includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes
|
||||
includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes
|
||||
for incl in includes:
|
||||
source_buffer_device += SubstituteTemplate(
|
||||
IncludeTemplate,
|
||||
|
||||
@@ -44,6 +44,7 @@ from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torc
|
||||
|
||||
dtype2ctype = {
|
||||
DataType.f16: ctypes.c_uint16,
|
||||
DataType.bf16: ctypes.c_uint16,
|
||||
DataType.f32: ctypes.c_float,
|
||||
DataType.f64: ctypes.c_double,
|
||||
DataType.s8: ctypes.c_int8,
|
||||
|
||||
@@ -59,18 +59,21 @@ def max(x, dim):
|
||||
elif is_torch_tensor(x):
|
||||
return torch.amax(x, dim)
|
||||
|
||||
|
||||
def maximum(x, y):
    """
    Elementwise maximum of ``x`` and ``y``, dispatched on the tensor type of ``x``.

    NumPy inputs are handled by ``np.maximum``. Torch inputs are handled by
    ``torch.maximum`` after ``y`` is wrapped with ``torch.tensor``. Any other
    ``x`` produces ``None``.
    """
    if is_numpy_tensor(x):
        return np.maximum(x, y)
    if is_torch_tensor(x):
        other = torch.tensor(y)
        return torch.maximum(x, other)
    return None
|
||||
|
||||
|
||||
|
||||
def minimum(x, y):
    """
    Elementwise minimum of ``x`` and ``y``, dispatched on the tensor type of ``x``.

    NumPy inputs are handled by ``np.minimum``. Torch inputs are handled by
    ``torch.minimum`` after ``y`` is wrapped with ``torch.tensor``. Any other
    ``x`` produces ``None``.
    """
    if is_numpy_tensor(x):
        return np.minimum(x, y)
    if is_torch_tensor(x):
        other = torch.tensor(y)
        return torch.minimum(x, other)
    return None
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Layout manipulate nodes
|
||||
##############################################################################
|
||||
|
||||
@@ -51,6 +51,20 @@ _generator_ccs = [50, 60, 61, 70, 75, 80, 90]
|
||||
# Strip any additional information from the CUDA version
|
||||
_cuda_version = __version__.split("rc")[0]
|
||||
|
||||
# Check that the Python CUDA version is greater than or equal to the NVCC version
_nvcc_version = cutlass.nvcc_version()
_cuda_list = _cuda_version.split('.')
# The NVCC components must come from _nvcc_version: splitting _cuda_version here would
# compare the Python CUDA version with itself, so the check below could never fail.
_nvcc_list = _nvcc_version.split('.')
for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
    if int(val_cuda) > int(val_nvcc):
        # A more significant component is already newer, so the remaining
        # (less significant) components cannot make Python CUDA older than NVCC.
        break
    if int(val_cuda) < int(val_nvcc):
        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")

if len(_nvcc_list) > len(_cuda_list):
    # NVCC reports more components than Python CUDA: accept exactly one extra component
    if len(_nvcc_list) != len(_cuda_list) + 1:
        raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
    # With an identical prefix, any non-zero trailing component makes NVCC the newer one
    if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
        raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
|
||||
|
||||
|
||||
class KernelsForDataType:
|
||||
"""
|
||||
@@ -278,7 +292,7 @@ class ArchOptions:
|
||||
]
|
||||
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
|
||||
manifest = cutlass_library.manifest.Manifest(manifest_args)
|
||||
generate_function(manifest, _cuda_version)
|
||||
generate_function(manifest, _nvcc_version)
|
||||
|
||||
if operation_kind not in manifest.operations:
|
||||
# No kernels generated for this architecture, this could be because the CUDA
|
||||
|
||||
@@ -818,6 +818,7 @@ ${compile_guard_end}
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
|
||||
@@ -177,7 +177,7 @@ def CreateGemmUniversal3xOperator(
|
||||
complex_transforms=None,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
swizzling_functor=SwizzlingFunctor.Identity1,
|
||||
tile_schedulers=[TileSchedulerType.Persistent]):
|
||||
tile_schedulers=[TileSchedulerType.Default]):
|
||||
|
||||
if type(data_types) is dict:
|
||||
data_types = [data_types]
|
||||
@@ -226,7 +226,7 @@ def CreateSparseGemmUniversal3xOperator(
|
||||
complex_transforms=None,
|
||||
epilogue_functor=EpilogueFunctor.LinearCombination,
|
||||
swizzling_functor=SwizzlingFunctor.Identity1,
|
||||
tile_schedulers=[TileSchedulerType.Persistent]):
|
||||
tile_schedulers=[TileSchedulerType.Default]):
|
||||
|
||||
if type(data_types) is dict:
|
||||
data_types = [data_types]
|
||||
@@ -1048,7 +1048,7 @@ def CreateConvOperator3x(manifest: Manifest,
|
||||
schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \
|
||||
[(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)],
|
||||
complex_transforms: Optional[Sequence[ComplexTransform]] = None,
|
||||
tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Persistent],
|
||||
tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default],
|
||||
conv_kind: ConvKind = ConvKind.Fprop,
|
||||
log_indent_level: int = 1):
|
||||
"""
|
||||
@@ -6508,6 +6508,7 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version):
|
||||
data_type, alignment_constraints, BlasMode.hermitian)
|
||||
#
|
||||
|
||||
|
||||
###################################################################################################
|
||||
|
||||
def GenerateSM90_Conv3x(manifest, cuda_version,
|
||||
@@ -6703,6 +6704,7 @@ def GenerateSM90_Conv3x(manifest, cuda_version,
|
||||
product(
|
||||
(
|
||||
ConvKind.Dgrad,
|
||||
ConvKind.Wgrad
|
||||
),
|
||||
spatial_dims,
|
||||
(
|
||||
|
||||
@@ -75,6 +75,7 @@ class DataType(enum.Enum):
|
||||
u16 = enum_auto()
|
||||
u32 = enum_auto()
|
||||
u64 = enum_auto()
|
||||
s2 = enum_auto()
|
||||
s4 = enum_auto()
|
||||
s8 = enum_auto()
|
||||
s16 = enum_auto()
|
||||
@@ -92,11 +93,13 @@ class DataType(enum.Enum):
|
||||
cf32 = enum_auto()
|
||||
ctf32 = enum_auto()
|
||||
cf64 = enum_auto()
|
||||
cs2 = enum_auto()
|
||||
cs4 = enum_auto()
|
||||
cs8 = enum_auto()
|
||||
cs16 = enum_auto()
|
||||
cs32 = enum_auto()
|
||||
cs64 = enum_auto()
|
||||
cu2 = enum_auto()
|
||||
cu4 = enum_auto()
|
||||
cu8 = enum_auto()
|
||||
cu16 = enum_auto()
|
||||
@@ -126,6 +129,7 @@ DataTypeNames = {
|
||||
DataType.u16: "u16",
|
||||
DataType.u32: "u32",
|
||||
DataType.u64: "u64",
|
||||
DataType.s2: "s2",
|
||||
DataType.s4: "s4",
|
||||
DataType.s8: "s8",
|
||||
DataType.s16: "s16",
|
||||
@@ -143,11 +147,13 @@ DataTypeNames = {
|
||||
DataType.cf32: "cf32",
|
||||
DataType.ctf32: "ctf32",
|
||||
DataType.cf64: "cf64",
|
||||
DataType.cu2: "cu2",
|
||||
DataType.cu4: "cu4",
|
||||
DataType.cu8: "cu8",
|
||||
DataType.cu16: "cu16",
|
||||
DataType.cu32: "cu32",
|
||||
DataType.cu64: "cu64",
|
||||
DataType.cs2: "cs2",
|
||||
DataType.cs4: "cs4",
|
||||
DataType.cs8: "cs8",
|
||||
DataType.cs16: "cs16",
|
||||
@@ -164,6 +170,7 @@ DataTypeTag = {
|
||||
DataType.u16: "uint16_t",
|
||||
DataType.u32: "uint32_t",
|
||||
DataType.u64: "uint64_t",
|
||||
DataType.s2: "cutlass::int2b_t",
|
||||
DataType.s4: "cutlass::int4b_t",
|
||||
DataType.s8: "int8_t",
|
||||
DataType.s16: "int16_t",
|
||||
@@ -181,11 +188,13 @@ DataTypeTag = {
|
||||
DataType.cf32: "cutlass::complex<float>",
|
||||
DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
|
||||
DataType.cf64: "cutlass::complex<double>",
|
||||
DataType.cu2: "cutlass::complex<cutlass::uint2b_t>",
|
||||
DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
|
||||
DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
|
||||
DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
|
||||
DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
|
||||
DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
|
||||
DataType.cs2: "cutlass::complex<cutlass::int2b_t>",
|
||||
DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
|
||||
DataType.cs8: "cutlass::complex<cutlass::int8_t>",
|
||||
DataType.cs16: "cutlass::complex<cutlass::int16_t>",
|
||||
@@ -202,6 +211,7 @@ DataTypeSize = {
|
||||
DataType.u16: 16,
|
||||
DataType.u32: 32,
|
||||
DataType.u64: 64,
|
||||
DataType.s2: 2,
|
||||
DataType.s4: 4,
|
||||
DataType.s8: 8,
|
||||
DataType.s16: 16,
|
||||
@@ -219,11 +229,13 @@ DataTypeSize = {
|
||||
DataType.cf32: 64,
|
||||
DataType.ctf32: 32,
|
||||
DataType.cf64: 128,
|
||||
DataType.cu2: 4,
|
||||
DataType.cu4: 8,
|
||||
DataType.cu8: 16,
|
||||
DataType.cu16: 32,
|
||||
DataType.cu32: 64,
|
||||
DataType.cu64: 128,
|
||||
DataType.cs2: 4,
|
||||
DataType.cs4: 8,
|
||||
DataType.cs8: 16,
|
||||
DataType.cs16: 32,
|
||||
|
||||
@@ -492,6 +492,21 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
if not (is_fp8 and is_sparse):
|
||||
schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue])
|
||||
stream_k_schedules = []
|
||||
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
||||
if not is_fp8 or level >= 1:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
])
|
||||
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue
|
||||
@@ -526,17 +541,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
# persistent kernels with TMA epilogues
|
||||
if can_do_tma_epilogue:
|
||||
assert not requires_transposed_epilogue
|
||||
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
||||
if not is_fp8 or level >= 1:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
])
|
||||
if can_do_fp8_fast_accum:
|
||||
schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum,
|
||||
EpilogueScheduleType.TmaWarpSpecialized
|
||||
])
|
||||
if can_do_cooperative:
|
||||
# Sparse kernels only support FastAccum FP8 mainloop
|
||||
if not (is_fp8 and is_sparse):
|
||||
|
||||
Reference in New Issue
Block a user