mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-05 14:11:18 +00:00
update 3.8 v2 (#2112)
* update 3.8 v2 * update 3.8 --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@@ -217,8 +217,7 @@ def CreateGemmUniversal3xOperator(
|
||||
gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"]
|
||||
gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
|
||||
"vector_size" : data_type["sfd_type"]["vector_size"]}
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
|
||||
assert is_block_scaled(gemm_kind)
|
||||
|
||||
A_dtype = data_type["a_type"]
|
||||
B_dtype = data_type["b_type"]
|
||||
@@ -254,9 +253,6 @@ def CreateGemmUniversal3xOperator(
|
||||
|
||||
return operations
|
||||
|
||||
def is_grouped(gemm_kind):
|
||||
return gemm_kind == GemmKind.GroupedGemmUniversal3x
|
||||
|
||||
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
|
||||
def CreateSparseGemmUniversal3xOperator(
|
||||
manifest, layouts, tile_descriptions, data_types,
|
||||
@@ -6654,11 +6650,13 @@ def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int
|
||||
|
||||
sm100_cluster_shape_1sm = [
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
sm100_cluster_shape_2sm = [
|
||||
# cluster_m % 2 == 0 for 2sm
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
@@ -6718,6 +6716,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@@ -6765,6 +6764,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@@ -7517,8 +7517,227 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
# SM100 MMA with mixed F4/F6/F8 inputs + without block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
# [64, 128, 32],
|
||||
[128, 128, 32],
|
||||
# [64, 256, 32],
|
||||
[128, 256, 32],
|
||||
]
|
||||
|
||||
instruction_sizes_2sm = [
|
||||
# [128, 128, 32],
|
||||
# [128, 256, 32],
|
||||
[256, 128, 32],
|
||||
[256, 256, 32],
|
||||
]
|
||||
|
||||
ab_types = [
|
||||
DataType.f4, DataType.f6, DataType.f8,
|
||||
DataType.e2m1, DataType.e3m2, DataType.e4m3,
|
||||
]
|
||||
|
||||
acc_types = [ DataType.f32 ]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default, TileSchedulerType.StreamK
|
||||
]
|
||||
|
||||
min_cc = 100
|
||||
max_cc = 130
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
|
||||
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
|
||||
|
||||
# Usage:
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)
|
||||
)
|
||||
|
||||
math_instructions_2sm = []
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)
|
||||
)
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
# [1,2,1],
|
||||
[2,1,1],
|
||||
[1,1,1],
|
||||
# [1,4,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_1sm:
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
kernel_data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f32,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e5m2,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
}
|
||||
]
|
||||
|
||||
for kernel_data_type in kernel_data_types:
|
||||
# Filter out some kernel
|
||||
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( kernel_data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
|
||||
# Update layout alignment
|
||||
# alignment for d might be different for each kernel_data_type
|
||||
layouts_copy = copy.deepcopy(layouts)
|
||||
for layout in layouts_copy:
|
||||
# alignment for a
|
||||
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
|
||||
# alignment for b
|
||||
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
|
||||
# alignment for d
|
||||
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
|
||||
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
# [2,2,1],
|
||||
# [2,4,1],
|
||||
# [4,1,1],
|
||||
# [4,2,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_2sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_2sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
kernel_data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f32,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e5m2,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
}
|
||||
]
|
||||
|
||||
for kernel_data_type in kernel_data_types:
|
||||
# Filter some kernel
|
||||
if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( kernel_data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
|
||||
# Update layout alignment
|
||||
# alignment for d might be different for each kernel_data_type
|
||||
layouts_copy = copy.deepcopy(layouts)
|
||||
for layout in layouts_copy:
|
||||
# alignment for a
|
||||
layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"])
|
||||
# alignment for b
|
||||
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"])
|
||||
# alignment for d
|
||||
layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"])
|
||||
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
|
||||
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers)
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type],
|
||||
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers)
|
||||
|
||||
def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
@@ -7529,7 +7748,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
[128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases
|
||||
[128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases
|
||||
]
|
||||
|
||||
instruction_sizes_2sm = [
|
||||
@@ -7670,8 +7889,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
@@ -7766,21 +7984,21 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
else:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
|
||||
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
|
||||
@@ -7805,7 +8023,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
def tile_schedulers(sfdtype):
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
|
||||
if sfdtype["type"] == DataType.void:
|
||||
if sfdtype["type"] == DataType.void or grouped:
|
||||
return [TileSchedulerType.Default]
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
@@ -7826,6 +8044,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@@ -7853,6 +8075,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
# grouped GEMM does not support runtime data type yet
|
||||
if grouped and (is_runtime_datatype_a or is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
@@ -7972,15 +8198,21 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
# E2M1 x E2M1, vector size 16, UE4M3
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
|
||||
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
|
||||
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
|
||||
)
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
@@ -8085,18 +8317,20 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for data_type in data_types:
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
|
||||
fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]
|
||||
epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
|
||||
nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule]
|
||||
fp4_schedule = [fp4_kernel_schedule, epi_schedule]
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule]
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"])
|
||||
)
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
|
||||
@@ -8139,6 +8373,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
MathOperation.multiply_add)]
|
||||
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@@ -8237,6 +8472,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
]
|
||||
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@@ -8353,6 +8589,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [1,4,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@@ -8386,6 +8623,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@@ -8431,6 +8669,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [1,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
@@ -8498,6 +8737,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
@@ -8554,6 +8794,125 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
|
||||
|
||||
|
||||
def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
min_cc = 100
|
||||
max_cc = 130
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add)]
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,2,1], [2,1,1], [1,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.StreamK,
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_1sm:
|
||||
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f16,
|
||||
"d_type" : DataType.e4m3,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
}]
|
||||
|
||||
# Set alignment d based on Destination format.
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
|
||||
|
||||
for data_type in data_types:
|
||||
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]],
|
||||
tile_schedulers=tile_schedulers)
|
||||
|
||||
# 2xSM MMA kernels
|
||||
math_instructions_2sm = [
|
||||
MathInstruction(
|
||||
[256, 256, 32],
|
||||
DataType.e4m3, DataType.e4m3, DataType.f32,
|
||||
OpcodeClass.TensorOp,
|
||||
MathOperation.multiply_add),
|
||||
]
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_2sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_2sm[1],
|
||||
math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.f16,
|
||||
"d_type" : DataType.e4m3,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
}]
|
||||
|
||||
# Set alignment d based on Destination format.
|
||||
for layout in layouts:
|
||||
layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]]
|
||||
|
||||
for data_type in data_types:
|
||||
if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\
|
||||
( data_type["d_type"] == DataType.e5m2 ):
|
||||
continue
|
||||
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm
|
||||
else:
|
||||
epi_schedule = EpilogueScheduleType.ScheduleAuto
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers)
|
||||
|
||||
def GenerateSM100(manifest, cuda_version):
|
||||
#
|
||||
@@ -8570,13 +8929,19 @@ def GenerateSM100(manifest, cuda_version):
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
|
||||
# grouped GEMM
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
|
||||
GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version)
|
||||
|
||||
# StreamK is included in regular generation
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
|
||||
#
|
||||
# Block Scaled Gemm
|
||||
#
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
@@ -8955,8 +9320,8 @@ def GenerateSM90(manifest, cuda_version):
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x)
|
||||
GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x)
|
||||
GenerateSM90_TensorOp_1684_complex(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version)
|
||||
GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version)
|
||||
|
||||
Reference in New Issue
Block a user