mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v3.9 (#2185)
* v3.8 update x * fix blackwell gg * doc change * doc change * doc change --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com> Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
@@ -286,7 +286,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a"])
|
||||
is_supported_arch = (arch in ["100a", "101a", "120a"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -395,8 +395,36 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "101a":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "120a":
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
[ '128x128' ]
|
||||
]
|
||||
|
||||
# sm120 MMA instruction shapes
|
||||
blockscaled_sm120_mma_instruction_shapes = [
|
||||
[ 's16x8x64gemm',
|
||||
's16x8x32gemm'
|
||||
]
|
||||
]
|
||||
|
||||
# Restrict to two layouts to reduce L0 build and test time.
|
||||
blockscaled_sm120_mma_layouts = [ 'tn' ]
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a"
|
||||
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
|
||||
raise Exception(error_message)
|
||||
|
||||
# Statically encoded kernels are still added to generated_kernels
|
||||
@@ -446,8 +474,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})|"
|
||||
# CTA tiles for super MMA - only run one tile size to reduce build/test times
|
||||
supermma_kernel_cta_tiles = [
|
||||
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
||||
sm120_mma_kernel_cta_tiles = [
|
||||
# h1688, s1688, i16832, i8816
|
||||
[ '256x128' ],
|
||||
# d884, c1688,
|
||||
@@ -458,8 +486,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
[ '64x64' ]
|
||||
]
|
||||
|
||||
# super MMA instruction shapes, planar complex type excluded as they are not required
|
||||
supermma_instruction_shapes = [
|
||||
# sm120 MMA instruction shapes, planar complex type excluded as they are not required
|
||||
sm120_mma_instruction_shapes = [
|
||||
[ 'h1688gemm_(?!planar_complex)',
|
||||
's1688gemm_f16',
|
||||
's1688gemm_bf16',
|
||||
@@ -473,16 +501,16 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
# It's not pretty, but not sure why different instructions support different tile sizes.
|
||||
filter_regex_supermma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[0], supermma_kernel_cta_tiles[0]]]) + ").*"
|
||||
filter_regex_supermma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[1], supermma_kernel_cta_tiles[1]]]) + ").*"
|
||||
filter_regex_supermma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[2], supermma_kernel_cta_tiles[2]]]) + ").*"
|
||||
filter_regex_supermma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[3], supermma_kernel_cta_tiles[3]]]) + ").*"
|
||||
filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*"
|
||||
filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*"
|
||||
filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*"
|
||||
filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*"
|
||||
|
||||
filter_regex_supermma = f"({filter_regex_supermma_0})|({filter_regex_supermma_1})|({filter_regex_supermma_2})|({filter_regex_supermma_3})"
|
||||
filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})"
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_supermma})"
|
||||
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@@ -494,6 +522,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
if is_runtime_datatype_enabled:
|
||||
mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter)
|
||||
|
||||
|
||||
kernel_filter_re = re.compile(kernel_filter)
|
||||
testcase_counter = 0
|
||||
kernels_emitted = 0
|
||||
@@ -630,6 +660,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
max_k = (cta_tile_shape_k*8) - alignment_ab_max
|
||||
problem_shapes_k = [min_k, max_k]
|
||||
sm_count = 16
|
||||
swizzle_sizes = [0]
|
||||
# Larger k and less than half wave trigger streamk +separate reduction case to be generated
|
||||
if 'stream_k' in kernel_name:
|
||||
problem_shapes_k = [max_k, cta_tile_shape_k*32]
|
||||
@@ -649,145 +680,147 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
for beta in beta_values:
|
||||
for cluster_shape in runtime_cluster_shapes:
|
||||
for runtime_input_datatype in runtime_input_datatypes:
|
||||
grid_size = waves * sm_count
|
||||
cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
|
||||
if cluster_shape_m >= cluster_shape_n:
|
||||
grid_m = cluster_shape_m
|
||||
grid_n = grid_size / grid_m
|
||||
grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
|
||||
else:
|
||||
grid_n = cluster_shape_n
|
||||
grid_m = grid_size / grid_n
|
||||
grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
|
||||
for swizzle_size in swizzle_sizes:
|
||||
grid_size = waves * sm_count
|
||||
cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
|
||||
if cluster_shape_m >= cluster_shape_n:
|
||||
grid_m = cluster_shape_m
|
||||
grid_n = grid_size / grid_m
|
||||
grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
|
||||
else:
|
||||
grid_n = cluster_shape_n
|
||||
grid_m = grid_size / grid_n
|
||||
grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
|
||||
|
||||
verification_required = False
|
||||
if mode == "functional_L0" or mode == "functional_L1":
|
||||
if '_void_' not in kernel_name:
|
||||
verification_required = True
|
||||
verification_required = False
|
||||
if mode == "functional_L0" or mode == "functional_L1":
|
||||
if '_void_' not in kernel_name:
|
||||
verification_required = True
|
||||
|
||||
m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
|
||||
n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
|
||||
k = int(k)
|
||||
m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
|
||||
n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
|
||||
k = int(k)
|
||||
|
||||
# For functional testing, we want to perturb just a little from even shapes.
|
||||
# Only do this if the perturbation does not cause one of the dimensions of the
|
||||
# problem size to go to zero. This can occur for blockscaling kernels for which
|
||||
# the alignment requirements for A and B can be quite large (e.g., 256).
|
||||
if m > alignment_shift_m:
|
||||
m -= alignment_shift_m
|
||||
if n > alignment_shift_n:
|
||||
n -= alignment_shift_n
|
||||
# For functional testing, we want to perturb just a little from even shapes.
|
||||
# Only do this if the perturbation does not cause one of the dimensions of the
|
||||
# problem size to go to zero. This can occur for blockscaling kernels for which
|
||||
# the alignment requirements for A and B can be quite large (e.g., 256).
|
||||
if m > alignment_shift_m:
|
||||
m -= alignment_shift_m
|
||||
if n > alignment_shift_n:
|
||||
n -= alignment_shift_n
|
||||
|
||||
if '_n32t32_' in kernel_name:
|
||||
continue
|
||||
batch_count = 1
|
||||
if mode == "functional_L0" or mode == "functional_L1" :
|
||||
if index_waves == 0 and index_k == 0 :
|
||||
batch_count = 3 if mode == "functional_L0" else 5
|
||||
gemm_op = "gemm"
|
||||
if '_n32t32_' in kernel_name:
|
||||
continue
|
||||
batch_count = 1
|
||||
if mode == "functional_L0" or mode == "functional_L1" :
|
||||
if index_waves == 0 and index_k == 0 :
|
||||
batch_count = 3 if mode == "functional_L0" else 5
|
||||
gemm_op = "gemm"
|
||||
|
||||
profiler_reference_computing_override = profiler_reference_computing
|
||||
if "bstensorop" in kernel_name:
|
||||
profiler_reference_computing_override = "--mode=trace"
|
||||
gemm_op = "block_scaled_gemm"
|
||||
profiler_reference_computing_override = profiler_reference_computing
|
||||
if "bstensorop" in kernel_name:
|
||||
profiler_reference_computing_override = "--mode=trace"
|
||||
gemm_op = "block_scaled_gemm"
|
||||
|
||||
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
|
||||
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
|
||||
|
||||
assert m > 0 and n > 0 and k > 0
|
||||
assert m > 0 and n > 0 and k > 0
|
||||
|
||||
# Emit per-testcase metadata for perf testing usage, eventually in perf database
|
||||
metadata_dict = {
|
||||
"input_params": {
|
||||
'problem_size_category' : problem_size_category,
|
||||
'operation' : _getSubOperationType(operation),
|
||||
'datatype' : data_types,
|
||||
'layout' : layout3x,
|
||||
'm' : m,
|
||||
'n' : n,
|
||||
'k' : k,
|
||||
'beta' : beta,
|
||||
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta)
|
||||
},
|
||||
"runtime_params": {
|
||||
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
|
||||
'tilesize_m' : cta_tile_shape_m,
|
||||
'tilesize_n' : cta_tile_shape_n,
|
||||
'tilesize_k' : cta_tile_shape_k,
|
||||
'cluster_shape_m' : cluster_shape_m,
|
||||
'cluster_shape_n' : cluster_shape_n,
|
||||
}
|
||||
}
|
||||
|
||||
cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
|
||||
cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
|
||||
cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
|
||||
|
||||
|
||||
if dynamic_datatype:
|
||||
runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
|
||||
metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
|
||||
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
|
||||
|
||||
testcase_metadata = [
|
||||
f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
|
||||
f" --kernels={kernel_name}" +
|
||||
f" --m={str(m)}" +
|
||||
f" --n={str(n)}" +
|
||||
f" --k={str(k)}" +
|
||||
f" --cluster_m={str(cluster_shape_m)}" +
|
||||
f" --cluster_n={str(cluster_shape_n)}" +
|
||||
f" --cluster_k={str(cluster_shape_k)}" +
|
||||
f" --cluster_m_fallback={str(cluster_m_fallback)}" +
|
||||
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
|
||||
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
|
||||
f" --beta={str(beta)}" +
|
||||
f" --batch_count={str(batch_count)}" +
|
||||
f" --verification-required={str(verification_required).lower()}"
|
||||
] \
|
||||
|
||||
output_dynamic_datatype = dynamic_datatype
|
||||
if output_dynamic_datatype:
|
||||
testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
|
||||
f" --runtime_input_datatype_b={runtime_datatype_b}")
|
||||
|
||||
testcase_metadata.append(json.dumps(metadata_dict))
|
||||
testlist_csv_rows.append(testcase_metadata)
|
||||
testcase_counter += 1
|
||||
|
||||
alpha = 1.0
|
||||
|
||||
if dynamic_datatype:
|
||||
hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
|
||||
|
||||
# If kernel_name is new, initialize its feature set with defaults
|
||||
if hashed_kernel_name not in kernel_features:
|
||||
kernel_features[hashed_kernel_name] = {
|
||||
"is_support_dynamic_cluster": False,
|
||||
"is_support_dynamic_datatype": False,
|
||||
# Emit per-testcase metadata for perf testing usage, eventually in perf database
|
||||
metadata_dict = {
|
||||
"input_params": {
|
||||
'problem_size_category' : problem_size_category,
|
||||
'operation' : _getSubOperationType(operation),
|
||||
'datatype' : data_types,
|
||||
'layout' : layout3x,
|
||||
'm' : m,
|
||||
'n' : n,
|
||||
'k' : k,
|
||||
'beta' : beta,
|
||||
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta)
|
||||
},
|
||||
"runtime_params": {
|
||||
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
|
||||
'tilesize_m' : cta_tile_shape_m,
|
||||
'tilesize_n' : cta_tile_shape_n,
|
||||
'tilesize_k' : cta_tile_shape_k,
|
||||
'cluster_shape_m' : cluster_shape_m,
|
||||
'cluster_shape_n' : cluster_shape_n,
|
||||
}
|
||||
}
|
||||
|
||||
# Update features for the hashed kernel name
|
||||
kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
|
||||
kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
|
||||
cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
|
||||
cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
|
||||
cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
|
||||
|
||||
if hashed_kernel_name not in auditlist_csv_params_map:
|
||||
auditlist_csv_params_map[hashed_kernel_name] = []
|
||||
|
||||
audit_row_params = get_kernel_params(
|
||||
operation,
|
||||
hashed_kernel_name,
|
||||
(cluster_shape_m, cluster_shape_n, cluster_shape_k),
|
||||
(cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
|
||||
(m, n, k, batch_count),
|
||||
alpha, beta,
|
||||
dynamic_datatype, dynamic_cluster
|
||||
)
|
||||
if dynamic_datatype:
|
||||
runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
|
||||
metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
|
||||
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
|
||||
|
||||
auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
|
||||
testcase_metadata = [
|
||||
f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" +
|
||||
f" --kernels={kernel_name}" +
|
||||
f" --m={str(m)}" +
|
||||
f" --n={str(n)}" +
|
||||
f" --k={str(k)}" +
|
||||
f" --cluster_m={str(cluster_shape_m)}" +
|
||||
f" --cluster_n={str(cluster_shape_n)}" +
|
||||
f" --cluster_k={str(cluster_shape_k)}" +
|
||||
f" --cluster_m_fallback={str(cluster_m_fallback)}" +
|
||||
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
|
||||
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
|
||||
f" --beta={str(beta)}" +
|
||||
f" --batch_count={str(batch_count)}" +
|
||||
f" --swizzle_size={str(swizzle_size)}" +
|
||||
f" --verification-required={str(verification_required).lower()}"
|
||||
] \
|
||||
|
||||
if hashed_kernel_name not in auditlist_csv_map:
|
||||
audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
|
||||
auditlist_csv_map[hashed_kernel_name] = audit_row
|
||||
output_dynamic_datatype = dynamic_datatype
|
||||
if output_dynamic_datatype:
|
||||
testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
|
||||
f" --runtime_input_datatype_b={runtime_datatype_b}")
|
||||
|
||||
testcase_metadata.append(json.dumps(metadata_dict))
|
||||
testlist_csv_rows.append(testcase_metadata)
|
||||
testcase_counter += 1
|
||||
|
||||
alpha = 1.0
|
||||
|
||||
if dynamic_datatype:
|
||||
hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
|
||||
|
||||
# If kernel_name is new, initialize its feature set with defaults
|
||||
if hashed_kernel_name not in kernel_features:
|
||||
kernel_features[hashed_kernel_name] = {
|
||||
"is_support_dynamic_cluster": False,
|
||||
"is_support_dynamic_datatype": False,
|
||||
}
|
||||
|
||||
# Update features for the hashed kernel name
|
||||
kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
|
||||
kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
|
||||
|
||||
if hashed_kernel_name not in auditlist_csv_params_map:
|
||||
auditlist_csv_params_map[hashed_kernel_name] = []
|
||||
|
||||
audit_row_params = get_kernel_params(
|
||||
operation,
|
||||
hashed_kernel_name,
|
||||
(cluster_shape_m, cluster_shape_n, cluster_shape_k),
|
||||
(cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
|
||||
(m, n, k, batch_count),
|
||||
alpha, beta,
|
||||
dynamic_datatype, dynamic_cluster
|
||||
)
|
||||
|
||||
auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
|
||||
|
||||
if hashed_kernel_name not in auditlist_csv_map:
|
||||
audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
|
||||
auditlist_csv_map[hashed_kernel_name] = audit_row
|
||||
|
||||
with open(outfile_name, 'w') as testlist_csv:
|
||||
csv_writer = csv.writer(testlist_csv, delimiter=',')
|
||||
@@ -826,7 +859,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
for kernel_name in kernel_name_set:
|
||||
file.write(kernel_name + "\n")
|
||||
|
||||
# Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and superMMA kernels in cutlass2.x generated together.
|
||||
# Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together.
|
||||
if mode == "functional_L0" or mode == "functional_L1":
|
||||
# Sort the .csv file
|
||||
outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
|
||||
|
||||
Reference in New Issue
Block a user