v4.2 release. (#2587)

* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line.

* v4.2 release.
This commit is contained in:
Junkai-Wu
2025-08-23 06:11:24 +08:00
committed by GitHub
parent 11cad1f67b
commit a49a78ffef
351 changed files with 28182 additions and 2032 deletions

View File

@@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
):
# For functional testing, we prefer to run reference computing on device if any
reference_device_archs = ["100a"]
reference_device_archs = ["100a", "103a"]
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
@@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
@@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'bf16gemm_f32_f32_f32_f32_f32',
]
exclude_archs = arch not in ("103a")
if exclude_archs:
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
sm100_mma_data_type_runtime_dtype = [
'gemm.*f4_f4_f32_f32_f32',
'gemm.*f6_f6_f32_f32_f32',
@@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
sm103_block_scaled_data_type = [
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = [
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
@@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch in ["100a", "100f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
@@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["101a", "101f",
]:
elif arch in ["101a", "101f", "110a", "110f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f"]:
elif arch in ["103a"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})|" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f", "121a", "121f"]:
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
@@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
else:
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
raise Exception(error_message)
elif mode == "functional_L1":
@@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
sm103_block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = ['0x0x1']
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
f"({block_scaled_filter_regex_2sm})" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
sm120_mma_kernel_cta_tiles = [
# h1688, s1688, i16832, i8816
@@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
if arch in ["120a", "120f", "121a", "121f"]:
kernel_filter = f"({filter_regex_sm120_mma})"
else:
kernel_filter = f"({filter_regex_sm100_mma})"
else:
raise ValueError()