Support for Group GEMM in CUTLASS Profiler for Geforce and Spark (#3092)

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>

Author: dePaul Miller
Date: 2026-03-06 17:36:29 -08:00
Committed by: GitHub
Parent: e5fcd125a5
Commit: 73c59c055c
8 changed files with 88 additions and 36 deletions


@@ -46,7 +46,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
-if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
+if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
cutlass_example_add_executable(
79a_blackwell_geforce_nvfp4_bf16_gemm
79a_blackwell_geforce_nvfp4_bf16_gemm.cu


@@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
+if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
cutlass_example_add_executable(
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm
80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu


@@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a")
+if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a")
cutlass_example_add_executable(
87a_blackwell_geforce_fp8_bf16_gemm_blockwise
87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu


@@ -186,13 +186,13 @@ struct CollectiveBuilder<
// Basic storage block for new Scaling Factor Layouts
using mnBasicBlockShape = Shape<_32,_4>;
using mnBasicBlockStride = Stride<_16,_4>;
-using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
+using kBasicBlockShape = Shape<Int<(int)SFVectorSize>, Int<MMA_NSF>>;
using kBasicBlockStride = Stride<_0, _1>;
using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
using sSFA_strideM = sSF_strideMN;
-using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
+using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));
@@ -209,11 +209,6 @@ struct CollectiveBuilder<
using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{}));
using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{}));
-static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled<
-detail::sm120_smem_capacity_bytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
-static constexpr uint32_t SchedulerPipelineStageCount = 3;
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
@@ -232,6 +227,34 @@ struct CollectiveBuilder<
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>,
"Invalid builder schedule tag for grouped GEMM");
+static constexpr uint32_t SchedulerPipelineStageCount = 3;
+static constexpr int CLCResponseSize = sizeof(typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100<Shape<_1,_1,_1>,1>::CLCResponse{});
+static constexpr auto SchedulerPipelineStorage = IsGroupedGemmKernel ? sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>)
+: sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, Shape<_1,_1,_1>>::SharedStorage);
+static constexpr auto CLCResponseStorage = IsGroupedGemmKernel ? 0 : (SchedulerPipelineStageCount *
+CLCResponseSize);
+static constexpr auto TensorMapStorage =
+IsGroupedGemmKernel ? sizeof(cute::TmaDescriptor) * 2 /* we have two tensormaps in smem */ :
+0;
+// TensorMapReady pipeline storage (specific to grouped/array kernels)
+static constexpr auto TensorMapReadyPipelineStorage =
+IsGroupedGemmKernel ? sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage) :
+0;
+static constexpr int ReducedSmemCapacityBytes = detail::sm120_smem_capacity_bytes -
+SchedulerPipelineStorage -
+TensorMapStorage -
+TensorMapReadyPipelineStorage -
+CLCResponseStorage;
+static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled<
+ReducedSmemCapacityBytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
using KernelSchedule = cute::conditional_t<IsGroupedGemmKernel,
// PtrArray
cute::conditional_t<IsCooperative,
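
The hunk above replaces a fixed pipeline stage count with a budget: scheduler, CLC-response, tensormap, and tensormap-ready storage are carved out of shared memory first, and only the remainder is handed to the stage-count computation. The arithmetic is easy to model in isolation; in the sketch below every byte count except the 128-byte TMA descriptor is an illustrative assumption, not the real CUTLASS constant.

```python
# Rough model of the grouped-GEMM SMEM budgeting above. All byte counts
# except the TMA descriptor size are illustrative assumptions.
SM120_SMEM_CAPACITY = 99 * 1024              # assumed per-CTA capacity

def reduced_smem_capacity(is_grouped_gemm: bool, scheduler_stages: int = 3) -> int:
    tma_descriptor_bytes = 128               # sizeof(cute::TmaDescriptor)
    if is_grouped_gemm:
        scheduler_pipeline = 64              # assumed PipelineAsyncSharedStorage<8>
        clc_response       = 0               # grouped kernels skip CLC responses
        tensormap          = 2 * tma_descriptor_bytes  # two tensormaps in smem
        tensormap_ready    = 24              # assumed PipelineAsync<N>::SharedStorage
    else:
        scheduler_pipeline = 48              # assumed PipelineCLCFetchAsync storage
        clc_response       = scheduler_stages * 16  # assumed CLCResponse size
        tensormap          = 0               # non-grouped kernels need no smem tensormaps
        tensormap_ready    = 0
    return (SM120_SMEM_CAPACITY - scheduler_pipeline - tensormap
            - tensormap_ready - clc_response)

# The stage count is then however many (A, B, SFA, SFB) tile sets fit in the
# reduced budget, e.g. with an assumed 32 KiB per-stage footprint:
print(reduced_smem_capacity(True) // (32 * 1024))    # -> 3
```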


@@ -760,6 +760,8 @@ select_instr() {
(SfVectorSize == 32 && cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag>)
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>)
+|| (SfVectorSize == 32 && cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag>)
+|| (SfVectorSize == 32 && cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|| (SfVectorSize == 32 && cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>)
|| (SfVectorSize == 64 && cute::is_base_of_v<KernelScheduleBlockScaledSparseGemmSm100, BuilderScheduleTag>
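
select_instr() maps the scale-factor vector size and the builder's schedule tag to an MMA instruction; the two added disjuncts make the SfVectorSize == 32 instruction reachable from the grouped (PtrArray) schedule tags as well. A standalone model of that predicate, with issubclass standing in for cute::is_base_of_v (the class names mirror the C++ tags, but the function is a toy, not CUTLASS code):

```python
# Toy model of the instruction-selection predicate extended above.
class KernelScheduleAuto: ...
class KernelPtrArrayTmaWarpSpecializedCooperative: ...
class KernelPtrArrayTmaWarpSpecializedPingpong: ...

def selects_sf32_instruction(sf_vector_size: int, tag: type) -> bool:
    accepted_bases = (KernelScheduleAuto,
                      KernelPtrArrayTmaWarpSpecializedCooperative,
                      KernelPtrArrayTmaWarpSpecializedPingpong)
    # cute::is_base_of_v<Base, Tag> becomes issubclass(tag, Base)
    return sf_vector_size == 32 and issubclass(tag, accepted_bases)

assert selects_sf32_instruction(32, KernelPtrArrayTmaWarpSpecializedPingpong)
assert not selects_sf32_instruction(64, KernelPtrArrayTmaWarpSpecializedCooperative)
```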


@@ -11176,11 +11176,13 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
conv_kind = ConvKind.Fprop,
log_indent_level = log_indent_level)
-def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
+def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
+grouped = is_grouped(gemm_kind)
layouts = [
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]]
]
@@ -11206,16 +11208,17 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
acc_types = [ DataType.f32 ]
def is_pingpong(kernel_schedule):
-if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120:
+if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 or \
+kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
return True
else:
return False
def tile_schedulers(sfdtype, kernel_schedule):
# Pingpong kernel schedule doesn't support stream-K.
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
-if is_pingpong(kernel_schedule):
+if grouped or is_pingpong(kernel_schedule):
return [TileSchedulerType.Default]
elif sfdtype["type"] == DataType.void:
return [TileSchedulerType.Default]
@@ -11226,12 +11229,12 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
max_cc = 121
epi_type = DataType.f32
math_instructions = []
kernel_schedules = [
-KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120,
-KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120
+to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, grouped),
+to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120, grouped)
]
for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types):
@@ -11299,16 +11302,18 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type, kernel_schedule in product(data_types, kernel_schedules):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
-[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
+[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
-gemm_kind = GemmKind.BlockScaledUniversal3x
+gemm_kind = gemm_kind
)
-def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
+def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM120 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
+grouped = is_grouped(gemm_kind)
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]]
@@ -11344,11 +11349,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def is_pingpong(kernel_schedule):
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \
-kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120:
+kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 or \
+kernel_schedule == KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong:
return True
else:
return False
def is_nvf4(kernel_schedule):
if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \
kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120:
@@ -11360,7 +11366,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
# Pingpong kernel schedule doesn't support stream-K.
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K
-if is_pingpong(kernel_schedule):
+if grouped or is_pingpong(kernel_schedule):
return [TileSchedulerType.Default]
elif sfdtype["type"] == DataType.void:
return [TileSchedulerType.Default]
@@ -11374,12 +11380,12 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
math_instructions = []
-kernel_schedules = [
-KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
-KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
-KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120,
-KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120
-]
+kernel_schedules = list(set([
+to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, grouped),
+to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, grouped),
+to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, grouped),
+to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120, grouped)
+])) # ensure no duplicates
for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types):
math_instructions.append(
@@ -11394,12 +11400,16 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for math_inst in math_instructions:
for kernel_schedule in kernel_schedules:
tile_descriptions = []
+is_grouped_schedule = grouped
tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative
for tile_size in tile_sizes:
# nvf4 kernel only supports ue4m3 SF
# mxf4 kernel only supports ue8m0 SF
-if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
-(math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
+# grouped schedules only support ue8m0 (MXF4); NVF4 (ue4m3) grouped requires
+# NVF4-specific PtrArray schedule tags not yet available
+if (is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0) or \
+(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \
+(not is_grouped_schedule and math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)):
tile_descriptions.append(
TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
@@ -11482,10 +11492,10 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
-[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
+[[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
-gemm_kind = GemmKind.BlockScaledUniversal3x
-)
+gemm_kind = gemm_kind
+)
def GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version):
# SM120 MMA with mixed F4/F6/F8 inputs + block scale
@@ -12048,6 +12058,11 @@ def GenerateSM120(manifest, cuda_version):
GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
#
+# Grouped Block Scaled Gemm
+#
+GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
+GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
+#
# Sparse Block Scaled Gemm
#
GenerateSM120_Sparse_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
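
One detail of the fp4 generator above is worth spelling out: grouped kernels are emitted only for ue8m0 (MXF4) scale factors, since NVF4-specific PtrArray schedule tags do not exist yet, while the flat path keeps its existing nvf4-ue4m3 / mxf4-ue8m0 pairing. A standalone check of that gate (string stand-ins for the generator's DataType enum):

```python
# Standalone model of the tile-description gating added above.
def emits_tile(grouped: bool, sf_type: str, is_nvf4_schedule: bool) -> bool:
    return ((grouped and sf_type == "ue8m0") or
            (not grouped and sf_type == "ue4m3" and is_nvf4_schedule) or
            (not grouped and sf_type == "ue8m0" and not is_nvf4_schedule))

assert emits_tile(grouped=True,  sf_type="ue8m0", is_nvf4_schedule=False)   # grouped MXF4
assert not emits_tile(grouped=True, sf_type="ue4m3", is_nvf4_schedule=True) # grouped NVF4: skipped
assert emits_tile(grouped=False, sf_type="ue4m3", is_nvf4_schedule=True)    # flat NVF4
assert emits_tile(grouped=False, sf_type="ue8m0", is_nvf4_schedule=False)   # flat MXF4
```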


@@ -615,6 +615,9 @@ class KernelScheduleType(enum.Enum):
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
+PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 = enum_auto()
+PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 = enum_auto()
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
@@ -730,6 +733,8 @@ KernelScheduleTag = {
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
+KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120<3>',
+KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongBlockScaledSm120: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120<3>',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Sm120',
KernelScheduleType.SparseMxf8f6f4TmaWarpSpecializedAcc2x4Sm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120',
KernelScheduleType.SparseNvf4TmaWarpSpecializedSm120: 'cutlass::gemm::KernelSparseTmaWarpSpecializedNvf4Sm120',
@@ -1040,6 +1045,13 @@ def to_grouped_schedule(schedule, grouped):
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
+# SM120
+KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
+KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
+KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
}
return group_schedule_map[schedule]
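
The new SM120 entries collapse all six flat cooperative/pingpong schedules onto the two generic PtrArray tags, which is why the fp4 generator wraps its schedule list in list(set(...)). A self-contained model of the call pattern — the pass-through when grouped is false is inferred from how the generators call it, since that branch is outside this hunk:

```python
import enum

class Sched(enum.Enum):
    # Stand-ins for a few KernelScheduleType members
    Mxf4CoopSm120 = enum.auto()
    Nvf4CoopSm120 = enum.auto()
    PtrArrayCoop  = enum.auto()

_GROUP_SCHEDULE_MAP = {
    Sched.Mxf4CoopSm120: Sched.PtrArrayCoop,
    Sched.Nvf4CoopSm120: Sched.PtrArrayCoop,
}

def to_grouped_schedule(schedule: Sched, grouped: bool) -> Sched:
    # Identity for flat kernels (inferred); PtrArray lookup when grouped.
    return _GROUP_SCHEDULE_MAP[schedule] if grouped else schedule

# Two distinct flat schedules collapse to one grouped tag, hence the
# generator's de-duplication:
assert {to_grouped_schedule(s, True)
        for s in (Sched.Mxf4CoopSm120, Sched.Nvf4CoopSm120)} == {Sched.PtrArrayCoop}
```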


@@ -563,9 +563,9 @@ public:
}
operator_args.mainloop.ptr_SFA =
-static_cast<const typename Operator::GemmKernel::ElementSF**>(arguments->SFA);
+static_cast<const typename CollectiveMainloop::ElementSF**>(arguments->SFA);
operator_args.mainloop.ptr_SFB =
-static_cast<const typename Operator::GemmKernel::ElementSF**>(arguments->SFB);
+static_cast<const typename CollectiveMainloop::ElementSF**>(arguments->SFB);
operator_args.mainloop.layout_SFA =
static_cast<typename CollectiveMainloop::InternalLayoutSFA*>(this->layout_SFA_device.data());