Release v4.0.0 (#2294)

This commit is contained in:
Kihiro Bando
2025-05-13 15:55:29 -04:00
committed by GitHub
parent ad7b2f5e84
commit f115c3f854
299 changed files with 51495 additions and 4413 deletions

View File

@@ -183,10 +183,7 @@ class GemmOperation:
math_op = self.tile_description.math_instruction.math_operation
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
if self.is_3x:
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
else:
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
inst_shape += math_op_string
@@ -194,7 +191,9 @@ class GemmOperation:
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
short_math_name = self.short_math_name() if not self.is_3x else ""
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
# Generates a string representing the MMA instruction.
def extended_name(self):
@@ -337,18 +336,36 @@ class GemmOperation:
def opcode_class_name(self):
return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
def get_collective_tile_shape(self):
"""
Get the tile shape passed to the collective builder.
On Blackwell, this is different than the operation.tile_description.tile_shape.
"""
is_sm100_kernel = (self.arch == 100)
if not is_sm100_kernel:
return self.tile_description.tile_shape
opcode_class_main = self.tile_description.math_instruction.opcode_class
instruction_shape = self.tile_description.math_instruction.instruction_shape
tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape
if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]:
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
return (tile_shape_m, tile_shape_n, tile_shape_k)
# Generates the full kernel function name
def procedural_name(self):
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
if self.arch >= 90:
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
tile_shape = self.get_collective_tile_shape()
return kernel_name_template.format(
p = self.prefix,
ar = self.arch,
op = opcode_class_name,
ex = self.extended_name_3x(),
ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "",
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
l = self.tile_description.stages,
s = self.layout_name_3x(),
@@ -920,28 +937,8 @@ ${compile_guard_end}
instruction_shape = operation.tile_description.math_instruction.instruction_shape
cluster_m = operation.tile_description.cluster_shape[0]
cluster_n = operation.tile_description.cluster_shape[1]
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
# account for static/dynamic cluster shapes
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
# Shape passed to epilogue builder
is_sm100_kernel = (operation.arch == 100)
if is_sm100_kernel:
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
if cluster_m <= 0:
cta_m = cta_m // cta_m_per_mma_instruction
if opcode_class_main in [OpcodeClass.TensorOp
, OpcodeClass.BlockScaledTensorOp
, OpcodeClass.SparseTensorOp
]:
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
# stage count set to zero indicates builder automatic stage selection
if operation.tile_description.stages > 0: