mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.2 release. (#2587)
* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line. * v4.2 release.
This commit is contained in:
@@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
|
||||
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
):
|
||||
# For functional testing, we prefer to run reference computing on device if any
|
||||
reference_device_archs = ["100a"]
|
||||
reference_device_archs = ["100a", "103a"]
|
||||
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
|
||||
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
|
||||
|
||||
@@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
exclude_archs = arch not in ("103a")
|
||||
if exclude_archs:
|
||||
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'gemm.*f4_f4_f32_f32_f32',
|
||||
'gemm.*f6_f6_f32_f32_f32',
|
||||
@@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
sm103_block_scaled_data_type = [
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = [
|
||||
'4x4x1', '2x1x1',
|
||||
'0x0x1' # dynamic cluster
|
||||
@@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch in ["100a", "100f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
@@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["101a", "101f",
|
||||
]:
|
||||
elif arch in ["101a", "101f", "110a", "110f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f"]:
|
||||
elif arch in ["103a"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f", "121a", "121f"]:
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
@@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
|
||||
raise Exception(error_message)
|
||||
|
||||
elif mode == "functional_L1":
|
||||
@@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
sm103_block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
f"({block_scaled_filter_regex_2sm})" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
||||
sm120_mma_kernel_cta_tiles = [
|
||||
# h1688, s1688, i16832, i8816
|
||||
@@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
|
||||
if arch in ["120a", "120f", "121a", "121f"]:
|
||||
kernel_filter = f"({filter_regex_sm120_mma})"
|
||||
else:
|
||||
kernel_filter = f"({filter_regex_sm100_mma})"
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user