update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-19 19:03:14 -08:00
committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@@ -217,8 +217,7 @@ def CreateGemmUniversal3xOperator(
gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"]
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
"vector_size" : data_type["sfd_type"]["vector_size"]}
gemm_kind = GemmKind.BlockScaledUniversal3x
assert is_block_scaled(gemm_kind)
A_dtype = data_type["a_type"]
B_dtype = data_type["b_type"]
@@ -254,9 +253,6 @@ def CreateGemmUniversal3xOperator(
return operations
def is_grouped(gemm_kind):
  """Return True when *gemm_kind* denotes the grouped GEMM universal-3x variant."""
  # Membership test against a tuple keeps this trivially extensible if more
  # grouped kinds are ever added.
  return gemm_kind in (GemmKind.GroupedGemmUniversal3x,)
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
def CreateSparseGemmUniversal3xOperator(
manifest, layouts, tile_descriptions, data_types,
@@ -6654,11 +6650,13 @@ def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int
sm100_cluster_shape_1sm = [
[4,4,1]
, DynamicClusterShape
]
sm100_cluster_shape_2sm = [
# cluster_m % 2 == 0 for 2sm
[4,4,1]
, DynamicClusterShape
]
def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
@@ -6718,6 +6716,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@@ -6765,6 +6764,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@@ -7517,8 +7517,227 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
# SM100 MMA with mixed F4/F6/F8 inputs + without block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
instruction_sizes_1sm = [
# [64, 128, 32],
[128, 128, 32],
# [64, 256, 32],
[128, 256, 32],
]
instruction_sizes_2sm = [
# [128, 128, 32],
# [128, 256, 32],
[256, 128, 32],
[256, 256, 32],
]
ab_types = [
DataType.f4, DataType.f6, DataType.f8,
DataType.e2m1, DataType.e3m2, DataType.e4m3,
]
acc_types = [ DataType.f32 ]
tile_schedulers = [
TileSchedulerType.Default, TileSchedulerType.StreamK
]
min_cc = 100
max_cc = 130
epi_type = DataType.f32
math_instructions_1sm = []
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
# Usage:
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.TensorOp,
MathOperation.multiply_add)
)
math_instructions_2sm = []
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.TensorOp,
MathOperation.multiply_add)
)
cluster_shapes_1sm = [
# [1,2,1],
[2,1,1],
[1,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_1sm:
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_1sm[0],
math_inst.instruction_shape[1] * multiplier_1sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
kernel_data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f32,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e5m2,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}
]
for kernel_data_type in kernel_data_types:
# Filter out some kernel
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
( kernel_data_type["d_type"] == DataType.e5m2 ):
continue
# Update layout alignment
# alignment for d might be different for each kernel_data_type
layouts_copy = copy.deepcopy(layouts)
for layout in layouts_copy:
# alignment for a
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
# alignment for b
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
# alignment for d
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers)
cluster_shapes_2sm = [
[2,1,1],
# [2,2,1],
# [2,4,1],
# [4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_2sm[0],
math_inst.instruction_shape[1] * multiplier_2sm[1],
math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
kernel_data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.f32,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e5m2,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
}
]
for kernel_data_type in kernel_data_types:
# Filter some kernel
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
( kernel_data_type["d_type"] == DataType.e5m2 ):
continue
# Update layout alignment
# alignment for d might be different for each kernel_data_type
layouts_copy = copy.deepcopy(layouts)
for layout in layouts_copy:
# alignment for a
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
# alignment for b
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
# alignment for d
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
if math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers)
else:
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers)
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with mixed F4/F6/F8 inputs + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
@@ -7529,7 +7748,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
]
instruction_sizes_1sm = [
[128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases
[128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases
]
instruction_sizes_2sm = [
@@ -7670,8 +7889,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
cluster_shapes_2sm = [
[2,1,1],
@@ -7766,21 +7984,21 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
else:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
@@ -7805,7 +8023,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def tile_schedulers(sfdtype):
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
if sfdtype["type"] == DataType.void:
if sfdtype["type"] == DataType.void or grouped:
return [TileSchedulerType.Default]
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
@@ -7826,6 +8044,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
@@ -7853,6 +8075,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
# grouped GEMM does not support runtime data type yet
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
@@ -7972,15 +8198,21 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
# E2M1 x E2M1, vector size 16, UE4M3
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
cluster_shapes_2sm = [
@@ -8085,18 +8317,20 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
)
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
@@ -8139,6 +8373,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
MathOperation.multiply_add)]
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@@ -8237,6 +8472,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
]
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@@ -8353,6 +8589,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [1,4,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@@ -8386,6 +8623,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@@ -8431,6 +8669,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [4,4,1]
, DynamicClusterShape
]
tile_schedulers = [
@@ -8498,6 +8737,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
@@ -8554,6 +8794,125 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
  """Emit SM100 FP8 (e4m3 x e4m3) UMMA GEMM kernels using only the stream-K tile scheduler.

  Registers both 1-SM and 2-SM MMA variants into *manifest*. Requires
  CUDA toolkit >= 12.0; otherwise generates nothing.
  """
  if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
    return

  # Layouts for A/B/C and their alignments. The D alignment (slot [2][1])
  # starts at 0 and is filled in below from the destination data type.
  layouts = [
    [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor,    16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.RowMajor,    16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.RowMajor,    16], [LayoutType.RowMajor,    16], [LayoutType.ColumnMajor, 0]],
    [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor,    0]],
  ]

  min_cc = 100
  max_cc = 130
  epi_type = DataType.f32

  # Stream-K only: the regular generator already covers the default scheduler.
  tile_schedulers = [
    TileSchedulerType.StreamK,
  ]

  # ------------------------------ 1xSM MMA kernels ------------------------------
  math_instructions_1sm = [
    MathInstruction(
      [128, 256, 32],
      DataType.e4m3, DataType.e4m3, DataType.f32,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add),
  ]

  cluster_shapes_1sm = [
    [1, 2, 1], [2, 1, 1], [1, 1, 1], [4, 4, 1],
    DynamicClusterShape,
  ]

  for math_inst in math_instructions_1sm:
    tile_descriptions = []
    for cluster_shape in cluster_shapes_1sm:
      # A dynamic cluster shape keeps the base instruction tile; static shapes scale it.
      multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
      tile_descriptions.append(
        TileDescription([
          math_inst.instruction_shape[0] * multiplier_1sm[0],
          math_inst.instruction_shape[1] * multiplier_1sm[1],
          math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
          0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))

    data_types = [
      {
        "a_type"   : math_inst.element_a,
        "b_type"   : math_inst.element_b,
        "c_type"   : DataType.f16,
        "d_type"   : DataType.e4m3,
        "acc_type" : math_inst.element_accumulator,
        "epi_type" : epi_type,
      },
    ]

    # Set alignment of D based on the destination format.
    for layout in layouts:
      layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]

    for data_type in data_types:
      # Skip the unsupported e4m3 x e4m3 -> e5m2 combination. With the current
      # single d_type of e4m3 this never triggers; kept for parity with the
      # non-stream-K FP8 generator.
      if (data_type["a_type"] == DataType.e4m3) and (data_type["b_type"] == DataType.e4m3) \
          and (data_type["d_type"] == DataType.e5m2):
        continue
      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
        [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
        tile_schedulers=tile_schedulers)

  # ------------------------------ 2xSM MMA kernels ------------------------------
  math_instructions_2sm = [
    MathInstruction(
      [256, 256, 32],
      DataType.e4m3, DataType.e4m3, DataType.f32,
      OpcodeClass.TensorOp,
      MathOperation.multiply_add),
  ]

  cluster_shapes_2sm = [
    [2, 1, 1], [2, 2, 1], [2, 4, 1], [4, 1, 1], [4, 4, 1],
    DynamicClusterShape,
  ]

  for math_inst in math_instructions_2sm:
    tile_descriptions = []
    for cluster_shape in cluster_shapes_2sm:
      # For 2SM instructions the per-CTA tile along M is half the instruction extent.
      multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape \
          else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
      tile_descriptions.append(
        TileDescription([
          math_inst.instruction_shape[0] * multiplier_2sm[0],
          math_inst.instruction_shape[1] * multiplier_2sm[1],
          math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
          0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))

    data_types = [
      {
        "a_type"   : math_inst.element_a,
        "b_type"   : math_inst.element_b,
        "c_type"   : DataType.f16,
        "d_type"   : DataType.e4m3,
        "acc_type" : math_inst.element_accumulator,
        "epi_type" : epi_type,
      },
    ]

    # Set alignment of D based on the destination format.
    for layout in layouts:
      layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]

    for data_type in data_types:
      # Same (currently inert) e5m2-output filter as the 1SM path above.
      if (data_type["a_type"] == DataType.e4m3) and (data_type["b_type"] == DataType.e4m3) \
          and (data_type["d_type"] == DataType.e5m2):
        continue
      # M=128 2SM instructions pair with the explicit 2Sm epilogue; the
      # [256,256,32] instruction used here always takes the auto schedule.
      if math_inst.instruction_shape[0] == 128:
        epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
      else:
        epi_schedule = EpilogueScheduleType.ScheduleAuto
      CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
        [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]],
        tile_schedulers=tile_schedulers)
def GenerateSM100(manifest, cuda_version):
#
@@ -8570,13 +8929,19 @@ def GenerateSM100(manifest, cuda_version):
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
# grouped GEMM
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version)
# StreamK is included in regular generation
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
#
# Block Scaled Gemm
#
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
###################################################################################################
@@ -8955,8 +9320,8 @@ def GenerateSM90(manifest, cuda_version):
GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version)
GenerateSM90_TensorOp_1684(manifest, cuda_version)
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)