Support for Group GEMM in CUTLASS Profiler for GeForce and Spark (#3092)

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
This commit is contained in:
dePaul Miller
2026-03-06 17:36:29 -08:00
committed by GitHub
parent e5fcd125a5
commit 73c59c055c
8 changed files with 88 additions and 36 deletions

View File

@@ -11176,11 +11176,13 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
conv_kind = ConvKind.Fprop,
log_indent_level = log_indent_level)
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
layouts = [
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]]
]
@@ -11206,16 +11208,17 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
acc_types = [ DataType.f32 ]
def is_pingpong(kernel_schedule):
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120:
if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 or \
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
return True
else:
return False
def tile_schedulers(sfdtype, kernel_schedule):
# Pingpong kernel schedule doesn't support stream-K.
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
if is_pingpong(kernel_schedule):
if grouped or is_pingpong(kernel_schedule):
return [TileSchedulerType.Default]
elif sfdtype["type"] == DataType.void:
return [TileSchedulerType.Default]
@@ -11226,12 +11229,12 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
max_cc = 121
epi_type = DataType.f32
math_instructions = []
kernel_schedules = [
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120,
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, grouped),
to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120, grouped)
]
for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types):
@@ -11299,16 +11302,18 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type, kernel_schedule in product(data_types, kernel_schedules):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
gemm_kind = GemmKind.BlockScaledUniversal3x
gemm_kind = gemm_kind
)
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM120 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]]
@@ -11344,11 +11349,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def is_pingpong(kernel_schedule):
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120:
kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 or \
kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
return True
else:
return False
def is_nvf4(kernel_schedule):
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \
kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120:
@@ -11360,7 +11366,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
# Pingpong kernel schedule doesn't support stream-K.
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
if is_pingpong(kernel_schedule):
if grouped or is_pingpong(kernel_schedule):
return [TileSchedulerType.Default]
elif sfdtype["type"] == DataType.void:
return [TileSchedulerType.Default]
@@ -11374,12 +11380,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
math_instructions = []
kernel_schedules = [
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120,
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120
]
kernel_schedules = list(set([
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, grouped),
to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, grouped),
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, grouped),
to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120, grouped)
])) # ensure no duplicates
for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types):
math_instructions.append(
@@ -11394,12 +11400,16 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for math_inst in math_instructions:
for kernel_schedule in kernel_schedules:
tile_descriptions = []
is_grouped_schedule = grouped
tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative
for tile_size in tile_sizes:
# nvf4 kernel only supports ue4m3 SF
# mxf4 kernel only supports ue8m0 SF
if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
(math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
# Grouped schedules only support ue8m0 (MXF4); grouped NVF4 (ue4m3) would require
# NVF4-specific PtrArray schedule tags, which are not yet available.
if (is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0) or \
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
tile_descriptions.append(
TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
@@ -11482,10 +11492,10 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
gemm_kind = GemmKind.BlockScaledUniversal3x
)
gemm_kind = gemm_kind
)
def GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
@@ -12048,6 +12058,11 @@ def GenerateSM120(manifest, cuda_version):
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
#
# Grouped Block Scaled Gemm
#
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
#
# Sparse Block Scaled Gemm
#
GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)

View File

@@ -615,6 +615,9 @@ class KernelScheduleType(enum.Enum):
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 = enum_auto()
PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 = enum_auto()
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
@@ -730,6 +733,8 @@ KernelScheduleTag = {
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120<3>',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120<3>',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
@@ -1040,6 +1045,13 @@ def to_grouped_schedule(schedule, grouped):
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
# SM120
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
}
return group_schedule_map[schedule]