mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-03 21:21:16 +00:00
Release v4.0.0 (#2294)
This commit is contained in:
@@ -75,15 +75,10 @@ audit_csv_runtime_fields = [
|
||||
]
|
||||
|
||||
def hash_cutlass_string(input_string):
|
||||
# Regex pattern to match instruction shape
|
||||
instruction_shape_pattern = r"[a-zA-Z]\d+x\d+x\d+" # Matches '_s128x128x64', '_h64x128x16', etc.
|
||||
mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
||||
|
||||
# Remove instruction shape (e.g., '_s128x128x64', '_h64x128x16')
|
||||
output = re.sub(instruction_shape_pattern, "", input_string)
|
||||
|
||||
# Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
||||
output = re.sub(mma_cluster_shape_pattern, "", output)
|
||||
output = re.sub(mma_cluster_shape_pattern, "", input_string)
|
||||
|
||||
return output
|
||||
|
||||
@@ -288,7 +283,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "101a", "120a"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -300,23 +295,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
#
|
||||
|
||||
sm100_mma_data_type_general = [
|
||||
'x16gemm_f16_f16_f16_f16_f16',
|
||||
'x16gemm_f16_f16_f16_void_f16',
|
||||
'x16gemm_f16_f16_f32_f16_f16',
|
||||
'x8tf32gemm_f32_f32_f32_f32_f32',
|
||||
'x16bf16gemm_f32_f32_f32_f32_f32',
|
||||
'gemm_f16_f16_f16_f16_f16',
|
||||
'gemm_f16_f16_f16_void_f16',
|
||||
'gemm_f16_f16_f32_f16_f16',
|
||||
'tf32gemm_f32_f32_f32_f32_f32',
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'x32gemm_f4_f4_f32_f32_f32',
|
||||
'x32gemm_f6_f6_f32_f32_f32',
|
||||
'x32gemm_f8_f8_f32_f32_f32',
|
||||
'gemm_f4_f4_f32_f32_f32',
|
||||
'gemm_f6_f6_f32_f32_f32',
|
||||
'gemm_f8_f8_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_mergeable = [
|
||||
'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'x32gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'x32gemm_e3m2_e3m2_f32_f32_f32',
|
||||
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'gemm_e3m2_e3m2_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_cluster_size = [
|
||||
@@ -331,22 +326,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ntn'
|
||||
]
|
||||
|
||||
sm100_mma_instruction_shape = [
|
||||
# [0] .1CTA, General
|
||||
['64x128', '128x128', '128x256'],
|
||||
# [1] .2CTA, General
|
||||
['128x128', '256x128', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
#
|
||||
# Block Scale Gemm
|
||||
@@ -354,19 +342,19 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_data_type_base = [
|
||||
# runtime datatypes
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type_mergeable = [
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
|
||||
@@ -377,56 +365,43 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch == "100a":
|
||||
if arch == "100a" or arch == "100f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "101a":
|
||||
elif arch == "101a" or arch == "101f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "120a":
|
||||
elif arch == "120a" or arch == "120f":
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
[ '128x128' ]
|
||||
]
|
||||
|
||||
# sm120 MMA instruction shapes
|
||||
blockscaled_sm120_mma_instruction_shapes = [
|
||||
[ 's16x8x64gemm',
|
||||
's16x8x32gemm'
|
||||
]
|
||||
]
|
||||
|
||||
# Restrict to two layouts to reduce L0 build and test time.
|
||||
blockscaled_sm120_mma_layouts = [ 'tn' ]
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
raise Exception(error_message)
|
||||
|
||||
# Statically encoded kernels are still added to generated_kernels
|
||||
@@ -445,14 +420,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
# Restrict to two layouts to reduce L1 build and test time.
|
||||
sm100_mma_layouts = ['tnt', 'ntn']
|
||||
sm100_mma_instruction_shape = [
|
||||
# .1CTA
|
||||
['64x128', '128x128', '128x256'],
|
||||
# .2CTA
|
||||
['128x128', '256x128', '256x256']
|
||||
]
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
@@ -463,15 +432,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
|
||||
Reference in New Issue
Block a user