mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-04 21:51:18 +00:00
Release v4.0.0 (#2294)
This commit is contained in:
@@ -183,10 +183,7 @@ class GemmOperation:
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
if self.is_3x:
|
||||
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
else:
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
|
||||
|
||||
inst_shape += math_op_string
|
||||
|
||||
@@ -194,7 +191,9 @@ class GemmOperation:
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
short_math_name = self.short_math_name() if not self.is_3x else ""
|
||||
|
||||
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
|
||||
# Generates a string representing the MMA instruction.
|
||||
def extended_name(self):
|
||||
@@ -337,18 +336,36 @@ class GemmOperation:
|
||||
def opcode_class_name(self):
    """Return the OpcodeClassNames entry for this operation's math-instruction opcode class."""
    math_instruction = self.tile_description.math_instruction
    return OpcodeClassNames[math_instruction.opcode_class]
|
||||
|
||||
def get_collective_tile_shape(self):
    """
    Return the tile shape that is passed to the collective builder.

    When ``self.arch`` is not 100, the tile description's ``tile_shape``
    object is returned unchanged. When it is 100 (the original note refers to
    this case as Blackwell), a new ``(m, n, k)`` tuple is returned: ``k``
    always comes from the tile shape, while ``m`` and ``n`` are taken from
    the math instruction's ``instruction_shape`` if the opcode class is
    TensorOp, BlockScaledTensorOp or SparseTensorOp, and from the tile shape
    otherwise.
    """
    description = self.tile_description

    # Guard clause: only arch 100 uses a collective tile shape that differs
    # from the tile description.
    if self.arch != 100:
        return description.tile_shape

    math_instruction = description.math_instruction
    opcode_class = math_instruction.opcode_class
    mma_shape = math_instruction.instruction_shape
    m, n, k = description.tile_shape

    mma_shaped_classes = [
        OpcodeClass.TensorOp,
        OpcodeClass.BlockScaledTensorOp,
        OpcodeClass.SparseTensorOp,
    ]
    if opcode_class in mma_shaped_classes:
        # M and N follow the MMA instruction shape; K stays as in the tile shape.
        m, n = mma_shape[0], mma_shape[1]

    return (m, n, k)
|
||||
|
||||
# Generates the full kernel function name
|
||||
def procedural_name(self):
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
if self.arch >= 90:
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
|
||||
tile_shape = self.get_collective_tile_shape()
|
||||
return kernel_name_template.format(
|
||||
p = self.prefix,
|
||||
ar = self.arch,
|
||||
op = opcode_class_name,
|
||||
ex = self.extended_name_3x(),
|
||||
ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "",
|
||||
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
|
||||
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
|
||||
l = self.tile_description.stages,
|
||||
s = self.layout_name_3x(),
|
||||
@@ -920,28 +937,8 @@ ${compile_guard_end}
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
||||
|
||||
# account for static/dynamic cluster shapes
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
||||
|
||||
|
||||
# Shape passed to epilogue builder
|
||||
is_sm100_kernel = (operation.arch == 100)
|
||||
if is_sm100_kernel:
|
||||
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
|
||||
if cluster_m <= 0:
|
||||
cta_m = cta_m // cta_m_per_mma_instruction
|
||||
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
, OpcodeClass.SparseTensorOp
|
||||
]:
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
if operation.tile_description.stages > 0:
|
||||
|
||||
Reference in New Issue
Block a user