mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
Release v4.0.0 (#2294)
This commit is contained in:
@@ -75,15 +75,10 @@ audit_csv_runtime_fields = [
|
||||
]
|
||||
|
||||
def hash_cutlass_string(input_string):
    """Return ``input_string`` with every MMA/cluster shape token removed.

    A shape token is an underscore immediately followed by three decimal
    numbers joined by ``x`` (e.g. ``_128x128x256`` or ``_0x0x1``). Tokens in
    which a letter sits between the underscore and the first number (e.g.
    ``_s128x128x64``) do not match and are left in place.

    Args:
        input_string: The string to strip shape tokens from.

    Returns:
        A copy of ``input_string`` with every matching token deleted.
    """
    # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
    mma_cluster_shape_pattern = r"_\d+x\d+x\d+"

    # A single substitution on the raw input is the whole transformation;
    # earlier intermediate substitutions were overwritten by this one and
    # never reached the return value.
    return re.sub(mma_cluster_shape_pattern, "", input_string)
|
||||
|
||||
@@ -288,7 +283,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "101a", "120a"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -300,23 +295,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
#
|
||||
|
||||
sm100_mma_data_type_general = [
|
||||
'x16gemm_f16_f16_f16_f16_f16',
|
||||
'x16gemm_f16_f16_f16_void_f16',
|
||||
'x16gemm_f16_f16_f32_f16_f16',
|
||||
'x8tf32gemm_f32_f32_f32_f32_f32',
|
||||
'x16bf16gemm_f32_f32_f32_f32_f32',
|
||||
'gemm_f16_f16_f16_f16_f16',
|
||||
'gemm_f16_f16_f16_void_f16',
|
||||
'gemm_f16_f16_f32_f16_f16',
|
||||
'tf32gemm_f32_f32_f32_f32_f32',
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'x32gemm_f4_f4_f32_f32_f32',
|
||||
'x32gemm_f6_f6_f32_f32_f32',
|
||||
'x32gemm_f8_f8_f32_f32_f32',
|
||||
'gemm_f4_f4_f32_f32_f32',
|
||||
'gemm_f6_f6_f32_f32_f32',
|
||||
'gemm_f8_f8_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_mergeable = [
|
||||
'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'x32gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'x32gemm_e3m2_e3m2_f32_f32_f32',
|
||||
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'gemm_e3m2_e3m2_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_cluster_size = [
|
||||
@@ -331,22 +326,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ntn'
|
||||
]
|
||||
|
||||
sm100_mma_instruction_shape = [
|
||||
# [0] .1CTA, General
|
||||
['64x128', '128x128', '128x256'],
|
||||
# [1] .2CTA, General
|
||||
['128x128', '256x128', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
#
|
||||
# Block Scale Gemm
|
||||
@@ -354,19 +342,19 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_data_type_base = [
|
||||
# runtime datatypes
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type_mergeable = [
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
|
||||
@@ -377,56 +365,43 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch == "100a":
|
||||
if arch == "100a" or arch == "100f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "101a":
|
||||
elif arch == "101a" or arch == "101f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "120a":
|
||||
elif arch == "120a" or arch == "120f":
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
[ '128x128' ]
|
||||
]
|
||||
|
||||
# sm120 MMA instruction shapes
|
||||
blockscaled_sm120_mma_instruction_shapes = [
|
||||
[ 's16x8x64gemm',
|
||||
's16x8x32gemm'
|
||||
]
|
||||
]
|
||||
|
||||
# Restrict to a single layout to reduce L0 build and test time.
|
||||
blockscaled_sm120_mma_layouts = [ 'tn' ]
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
raise Exception(error_message)
|
||||
|
||||
# Statically encoded kernels are still added to generated_kernels
|
||||
@@ -445,14 +420,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
# Restrict to two layouts to reduce L1 build and test time.
|
||||
sm100_mma_layouts = ['tnt', 'ntn']
|
||||
sm100_mma_instruction_shape = [
|
||||
# .1CTA
|
||||
['64x128', '128x128', '128x256'],
|
||||
# .2CTA
|
||||
['128x128', '256x128', '256x256']
|
||||
]
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
@@ -463,15 +432,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
|
||||
@@ -183,10 +183,7 @@ class GemmOperation:
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
if self.is_3x:
|
||||
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
else:
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
|
||||
|
||||
inst_shape += math_op_string
|
||||
|
||||
@@ -194,7 +191,9 @@ class GemmOperation:
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
short_math_name = self.short_math_name() if not self.is_3x else ""
|
||||
|
||||
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
|
||||
# Generates a string representing the MMA instruction.
|
||||
def extended_name(self):
|
||||
@@ -337,18 +336,36 @@ class GemmOperation:
|
||||
def opcode_class_name(self):
|
||||
return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
|
||||
def get_collective_tile_shape(self):
    """
    Return the tile shape that is passed to the collective builder.

    When ``self.arch`` is not 100, ``tile_description.tile_shape`` is returned
    unchanged. When it is 100 and the math instruction's opcode class is
    TensorOp, BlockScaledTensorOp or SparseTensorOp, the M and N extents come
    from the instruction shape while K is kept from the tile shape. In the
    arch-100 case the result is always a new ``(m, n, k)`` tuple.
    """
    if self.arch != 100:
        return self.tile_description.tile_shape

    math_inst = self.tile_description.math_instruction
    opcode_class = math_inst.opcode_class
    inst_shape = math_inst.instruction_shape
    m, n, k = self.tile_description.tile_shape

    uses_instruction_mn = opcode_class in (
        OpcodeClass.TensorOp,
        OpcodeClass.BlockScaledTensorOp,
        OpcodeClass.SparseTensorOp,
    )
    if uses_instruction_mn:
        # M and N follow the MMA instruction shape; K stays as configured.
        m, n = inst_shape[0], inst_shape[1]
    return (m, n, k)
|
||||
|
||||
# Generates the full kernel function name
|
||||
def procedural_name(self):
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
if self.arch >= 90:
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
|
||||
tile_shape = self.get_collective_tile_shape()
|
||||
return kernel_name_template.format(
|
||||
p = self.prefix,
|
||||
ar = self.arch,
|
||||
op = opcode_class_name,
|
||||
ex = self.extended_name_3x(),
|
||||
ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "",
|
||||
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
|
||||
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
|
||||
l = self.tile_description.stages,
|
||||
s = self.layout_name_3x(),
|
||||
@@ -920,28 +937,8 @@ ${compile_guard_end}
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
||||
|
||||
# account for static/dynamic cluster shapes
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
||||
|
||||
|
||||
# Shape passed to epilogue builder
|
||||
is_sm100_kernel = (operation.arch == 100)
|
||||
if is_sm100_kernel:
|
||||
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
|
||||
if cluster_m <= 0:
|
||||
cta_m = cta_m // cta_m_per_mma_instruction
|
||||
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
, OpcodeClass.SparseTensorOp
|
||||
]:
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
if operation.tile_description.stages > 0:
|
||||
|
||||
@@ -1003,14 +1003,11 @@ class ConvOperation3x:
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
inst_shape += math_op_string
|
||||
|
||||
if self.tile_description.math_instruction.element_a != self.A.element and \
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, ConvKindNames[self.conv_kind])
|
||||
return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind])
|
||||
|
||||
def extended_name(self):
|
||||
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
|
||||
@@ -5997,8 +5994,8 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc)
|
||||
|
||||
valid_types_for_d = [DataType.f32]
|
||||
valid_types_for_c = [DataType.f32]
|
||||
valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2]
|
||||
valid_types_for_c = copy.deepcopy(valid_types_for_d)
|
||||
|
||||
tile_descriptions = generate_tile_descriptions_sm90(
|
||||
math_instructions=math_instructions,
|
||||
@@ -6009,6 +6006,12 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
|
||||
math_inst = tile_desc.math_instruction
|
||||
data_types = []
|
||||
|
||||
# Limit C/D types to avoid a giant number of instantiations.
|
||||
# A typical use case for mixed dtype in DL is weight quantization (tensor A),
|
||||
# therefore we can limit the output type to that of activation (tensor B).
|
||||
valid_types_for_c = [math_inst.element_b]
|
||||
valid_types_for_d = [math_inst.element_b]
|
||||
|
||||
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
|
||||
data_types.append(
|
||||
generate_data_types_from_math_instruction(
|
||||
@@ -6791,6 +6794,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default
|
||||
]
|
||||
@@ -6838,6 +6846,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
@@ -6937,6 +6950,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default
|
||||
]
|
||||
@@ -7090,6 +7108,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -7247,6 +7270,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -7456,6 +7484,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -7916,6 +7949,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[2,1,1],
|
||||
[1,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -7985,6 +8025,12 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8138,6 +8184,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -8211,6 +8264,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
[4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8417,6 +8477,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -8537,6 +8604,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
[4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8689,6 +8763,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -8788,6 +8867,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8925,6 +9009,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -8953,6 +9040,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9044,6 +9134,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9072,6 +9165,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9163,6 +9259,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9191,6 +9290,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9287,6 +9389,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9319,6 +9424,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9417,6 +9525,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_1sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9476,6 +9587,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in sm100_cluster_shape_2sm:
|
||||
if 101 in manifest.compute_capabilities :
|
||||
if cluster_shape == [4,4,1] :
|
||||
continue
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
@@ -9578,6 +9692,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.StreamK,
|
||||
]
|
||||
@@ -9612,6 +9732,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -9658,6 +9784,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.StreamK
|
||||
]
|
||||
@@ -9726,6 +9858,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -9809,6 +9947,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [2,1,1], [1,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.StreamK,
|
||||
]
|
||||
@@ -9861,6 +10005,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -9960,6 +10110,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
|
||||
|
||||
# tile_descriptions is a 2-level list.
|
||||
# Each inner list is for each cluster shape.
|
||||
for math_inst, output_type in math_instructions_w_output_1sm:
|
||||
@@ -10023,6 +10176,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_2sm)
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_2sm:
|
||||
tile_descriptions = []
|
||||
@@ -10103,6 +10258,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_1sm)
|
||||
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -10166,6 +10323,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
data_types_and_instruction_shapes_2sm)
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
|
||||
|
||||
for math_inst, output_type in math_instructions_w_output_2sm:
|
||||
tile_descriptions = []
|
||||
@@ -10629,6 +10788,8 @@ def GenerateSM100(manifest, cuda_version):
|
||||
#
|
||||
# Dense Gemm
|
||||
#
|
||||
architectures = manifest.args.architectures.split(';') if len(args.architectures) else ['50',]
|
||||
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version)
|
||||
@@ -10636,7 +10797,8 @@ def GenerateSM100(manifest, cuda_version):
|
||||
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version)
|
||||
|
||||
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
if '100f' not in architectures and '101f' not in architectures:
|
||||
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
# grouped GEMM
|
||||
@@ -10657,7 +10819,8 @@ def GenerateSM100(manifest, cuda_version):
|
||||
#
|
||||
GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
if '100f' not in architectures and '101f' not in architectures:
|
||||
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
|
||||
|
||||
@@ -11166,7 +11329,7 @@ if __name__ == "__main__":
|
||||
GenerateSM89(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
blackwell_enabled_arch = any(arch in ["100a", "101a", "120a"] for arch in archs)
|
||||
blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs)
|
||||
if blackwell_enabled_arch:
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
GenerateSM120(manifest, args.cuda_version)
|
||||
|
||||
@@ -523,10 +523,14 @@ class Manifest:
|
||||
arch_conditional_cc = [
|
||||
'90a',
|
||||
'100a',
|
||||
'100f',
|
||||
'101a',
|
||||
'120a'
|
||||
'101f',
|
||||
'120a',
|
||||
'120f'
|
||||
]
|
||||
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
|
||||
architectures = [x if x not in arch_conditional_cc else x.split('f')[0] for x in architectures]
|
||||
|
||||
self.compute_capabilities = [int(x) for x in architectures]
|
||||
|
||||
|
||||
@@ -375,6 +375,13 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
|
||||
mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
|
||||
for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):
|
||||
|
||||
# generator can stamp out duplicate kernels, because it doesn't explicitly set instruction
|
||||
# shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using
|
||||
# the auto kernel schedule.
|
||||
|
||||
math_inst_stub = copy.deepcopy(math_inst)
|
||||
math_inst_stub.instruction_shape = [0, 0, 0]
|
||||
|
||||
tile_desc = TileDescription(
|
||||
threadblock_shape=[
|
||||
math_inst.instruction_shape[0] * mma_mul[0],
|
||||
@@ -383,7 +390,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
|
||||
],
|
||||
stages=0,
|
||||
warp_count=[4, 1, 1],
|
||||
math_instruction=math_inst,
|
||||
math_instruction=math_inst_stub,
|
||||
min_compute=90,
|
||||
max_compute=90,
|
||||
cluster_shape=cluster_size)
|
||||
@@ -551,6 +558,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
b_type_size = DataTypeSize[data_types["b_type"]]
|
||||
if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
schedules = []
|
||||
stream_k_schedules = []
|
||||
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
||||
if a_type_size > b_type_size:
|
||||
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
||||
@@ -579,7 +587,11 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
return schedules, []
|
||||
stream_k_schedules.append([
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
epilogue_schedule
|
||||
])
|
||||
return schedules, stream_k_schedules
|
||||
|
||||
if not is_aligned and not is_blockwise(gemm_kind):
|
||||
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
|
||||
|
||||
Reference in New Issue
Block a user