mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Support for Group GEMM in CUTLASS Profiler for Geforce and Spark (#3092)
Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
This commit is contained in:
@@ -11176,11 +11176,13 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
conv_kind = ConvKind.Fprop,
|
||||
log_indent_level = log_indent_level)
|
||||
|
||||
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]]
|
||||
]
|
||||
@@ -11206,16 +11208,17 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
acc_types = [ DataType.f32 ]
|
||||
|
||||
def is_pingpong(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120:
|
||||
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def tile_schedulers(sfdtype, kernel_schedule):
|
||||
# Pingpong kernel schedule doesn't support stream-K.
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
|
||||
if is_pingpong(kernel_schedule):
|
||||
if grouped or is_pingpong(kernel_schedule):
|
||||
return [TileSchedulerType.Default]
|
||||
elif sfdtype["type"] == DataType.void:
|
||||
return [TileSchedulerType.Default]
|
||||
@@ -11226,12 +11229,12 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
|
||||
math_instructions = []
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120
|
||||
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120, grouped)
|
||||
]
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types):
|
||||
@@ -11299,16 +11302,18 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
|
||||
for data_type, kernel_schedule in product(data_types, kernel_schedules):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
gemm_kind = gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM120 MMA with with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]]
|
||||
@@ -11344,11 +11349,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
def is_pingpong(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120:
|
||||
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 or \
|
||||
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_nvf4(kernel_schedule):
|
||||
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \
|
||||
kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120:
|
||||
@@ -11360,7 +11366,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
# Pingpong kernel schedule doesn't support stream-K.
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
|
||||
if is_pingpong(kernel_schedule):
|
||||
if grouped or is_pingpong(kernel_schedule):
|
||||
return [TileSchedulerType.Default]
|
||||
elif sfdtype["type"] == DataType.void:
|
||||
return [TileSchedulerType.Default]
|
||||
@@ -11374,12 +11380,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
math_instructions = []
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120
|
||||
]
|
||||
kernel_schedules = list(set([
|
||||
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, grouped),
|
||||
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120, grouped)
|
||||
])) # ensure no duplicates
|
||||
|
||||
for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types):
|
||||
math_instructions.append(
|
||||
@@ -11394,12 +11400,16 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
for math_inst in math_instructions:
|
||||
for kernel_schedule in kernel_schedules:
|
||||
tile_descriptions = []
|
||||
is_grouped_schedule = grouped
|
||||
tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative
|
||||
for tile_size in tile_sizes:
|
||||
# nvf4 kernel only supports ue4m3 SF
|
||||
# mxf4 kernel only supports ue8m0 SF
|
||||
if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
|
||||
(math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
|
||||
# grouped schedules only support ue8m0 (MXF4); NVF4 (ue4m3) grouped requires
|
||||
# NVF4-specific PtrArray schedule tags not yet available
|
||||
if (is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0) or \
|
||||
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
|
||||
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
|
||||
tile_descriptions.append(
|
||||
TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
@@ -11482,10 +11492,10 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
|
||||
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
|
||||
gemm_kind = GemmKind.BlockScaledUniversal3x
|
||||
)
|
||||
gemm_kind = gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
|
||||
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
|
||||
@@ -12048,6 +12058,11 @@ def GenerateSM120(manifest, cuda_version):
|
||||
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
#
|
||||
# Grouped Block Scaled Gemm
|
||||
#
|
||||
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
#
|
||||
# Sparse Block Scaled Gemm
|
||||
#
|
||||
GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
|
||||
@@ -615,6 +615,9 @@ class KernelScheduleType(enum.Enum):
|
||||
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
|
||||
PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 = enum_auto()
|
||||
|
||||
KernelScheduleTag = {
|
||||
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
||||
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
|
||||
@@ -730,6 +733,8 @@ KernelScheduleTag = {
|
||||
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120<3>',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120<3>',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
|
||||
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
|
||||
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
|
||||
@@ -1040,6 +1045,13 @@ def to_grouped_schedule(schedule, grouped):
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
|
||||
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
|
||||
# SM120
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
||||
}
|
||||
|
||||
return group_schedule_map[schedule]
|
||||
|
||||
Reference in New Issue
Block a user