Release v4.0.0 (#2294)

This commit is contained in:
Kihiro Bando
2025-05-13 15:55:29 -04:00
committed by GitHub
parent ad7b2f5e84
commit f115c3f854
299 changed files with 51495 additions and 4413 deletions

View File

@@ -75,15 +75,10 @@ audit_csv_runtime_fields = [
]
def hash_cutlass_string(input_string):
    """Return *input_string* with MMA/Cluster shape suffixes stripped.

    Removes every ``_MxNxK`` numeric-shape token (e.g. '_128x128x256',
    '_0x0x1') so kernel names that differ only by shape hash identically.

    Args:
        input_string: a CUTLASS kernel procedural name.

    Returns:
        The name with all ``_\\d+x\\d+x\\d+`` substrings removed.
    """
    # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1').
    # NOTE: requires a digit right after '_', so letter-prefixed instruction
    # shapes such as '_s128x128x64' are intentionally left untouched.
    mma_cluster_shape_pattern = r"_\d+x\d+x\d+"
    # Original code computed two intermediate substitutions and then
    # overwrote the result from input_string again — those were dead stores.
    # A single substitution yields the same final value.
    output = re.sub(mma_cluster_shape_pattern, "", input_string)
    return output
@@ -288,7 +283,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "101a", "120a"])
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
@@ -300,23 +295,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
#
sm100_mma_data_type_general = [
'x16gemm_f16_f16_f16_f16_f16',
'x16gemm_f16_f16_f16_void_f16',
'x16gemm_f16_f16_f32_f16_f16',
'x8tf32gemm_f32_f32_f32_f32_f32',
'x16bf16gemm_f32_f32_f32_f32_f32',
'gemm_f16_f16_f16_f16_f16',
'gemm_f16_f16_f16_void_f16',
'gemm_f16_f16_f32_f16_f16',
'tf32gemm_f32_f32_f32_f32_f32',
'bf16gemm_f32_f32_f32_f32_f32',
]
sm100_mma_data_type_runtime_dtype = [
'x32gemm_f4_f4_f32_f32_f32',
'x32gemm_f6_f6_f32_f32_f32',
'x32gemm_f8_f8_f32_f32_f32',
'gemm_f4_f4_f32_f32_f32',
'gemm_f6_f6_f32_f32_f32',
'gemm_f8_f8_f32_f32_f32',
]
sm100_mma_data_type_mergeable = [
'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
'x32gemm_e2m1_e2m1_f32_f32_f32',
'x32gemm_e3m2_e3m2_f32_f32_f32',
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
'gemm_e2m1_e2m1_f32_f32_f32',
'gemm_e3m2_e3m2_f32_f32_f32',
]
sm100_mma_cluster_size = [
@@ -331,22 +326,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'ntn'
]
sm100_mma_instruction_shape = [
# [0] .1CTA, General
['64x128', '128x128', '128x256'],
# [1] .2CTA, General
['128x128', '256x128', '256x256'],
]
# regex list must be in kernel procedural name order
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
#
# Block Scale Gemm
@@ -354,19 +342,19 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_data_type_base = [
# runtime datatypes
'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type_mergeable = [
'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
@@ -377,56 +365,43 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
block_scaled_layouts = ['tnt']
block_scaled_instruction_shape = [
# .1CTA
['128x128', '128x192', '128x256'],
# .2CTA
['256x128', '256x192', '256x256'],
]
# regex list must be in kernel procedural name order
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch == "100a":
if arch == "100a" or arch == "100f":
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "101a":
elif arch == "101a" or arch == "101f":
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "120a":
elif arch == "120a" or arch == "120f":
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
[ '128x128' ]
]
# sm120 MMA instruction shapes
blockscaled_sm120_mma_instruction_shapes = [
[ 's16x8x64gemm',
's16x8x32gemm'
]
]
# Restrict to two layouts to reduce L0 build and test time.
blockscaled_sm120_mma_layouts = [ 'tn' ]
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
else:
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
raise Exception(error_message)
# Statically encoded kernels are still added to generated_kernels
@@ -445,14 +420,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
# Restrict to two layouts to reduce L1 build and test time.
sm100_mma_layouts = ['tnt', 'ntn']
sm100_mma_instruction_shape = [
# .1CTA
['64x128', '128x128', '128x256'],
# .2CTA
['128x128', '256x128', '256x256']
]
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
@@ -463,15 +432,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
block_scaled_layouts = ['tnt']
block_scaled_instruction_shape = [
# .1CTA
['128x128', '128x192', '128x256'],
# .2CTA
['256x128', '256x192', '256x256'],
]
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \