update 3.8 v2 (#2112)

* update 3.8 v2

* update 3.8

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-19 19:03:14 -08:00
committed by GitHub
parent e9627ce55b
commit b84e9802d8
166 changed files with 3986 additions and 4037 deletions

View File

@@ -64,17 +64,15 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
):
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
kinds_3x = {
GemmKind.Universal3x,
GemmKind.SparseUniversal3x,
GemmKind.BlockScaledUniversal3x,
GemmKind.GroupedGemmUniversal3x,
GemmKind.GroupedUniversal3x,
GemmKind.GroupedBlockScaledUniversal3x,
}
self.is_3x = gemm_kind in kinds_3x
self.prefix = "3x" if self.is_3x else ""
@@ -87,13 +85,11 @@ class GemmOperation:
self.C = C
self.D = D
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(gemm_kind):
self.ScaleFactorA = ScaleFactorA
self.ScaleFactorB = ScaleFactorB
self.ScaleFactorD = ScaleFactorD["tensor"]
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
if self.D == None:
self.D = self.C
@@ -239,13 +235,13 @@ class GemmOperation:
element_c = DataTypeNames[self.C.element],
element_d = DataTypeNames[self.D.element],
core_name = self.core_name())
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(self.gemm_kind):
d_type_names = DataTypeNames[self.D.element]
if self.ScaleFactorD.element != DataType.void:
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
element_sfa = DataTypeNames[self.ScaleFactorA],
element_a = DataTypeNames[self.A.element],
@@ -255,7 +251,7 @@ class GemmOperation:
element_c = DataTypeNames[self.C.element],
element_d = d_type_names,
core_name = self.core_name())
if self.mixed_input_mode != None:
extended_name = extended_name + self.mixed_input_mode_name()
return extended_name
@@ -298,8 +294,8 @@ class GemmOperation:
# Generates a short string representing underlying epilogue schedule type
def epilogue_schedule_name_3x(self):
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
if is_block_scaled(self.gemm_kind):
if self.ScaleFactorD.element != DataType.void:
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
@@ -779,7 +775,7 @@ class EmitGemmUniversal3xInstance:
using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch}, ${opcode_class_epi},
cute::Shape<cute::_${tile_shape_epi_m}, cute::_${tile_shape_epi_n}, cute::_${tile_shape_epi_k}>,
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${epi_tile_mn},
${element_accumulator}, ${element_epilogue},
@@ -797,7 +793,7 @@ using ${operation_name}_mainloop =
${element_a}, ${layout_a}, ${align_a},
${element_b}, ${layout_b}, ${align_b},
${element_accumulator},
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
${stages},
${kernel_schedule}
@@ -855,7 +851,7 @@ ${compile_guard_end}
@staticmethod
def pointerize_if_grouped(operation, layout):
return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* "
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
@staticmethod
def problem_shape(operation):
@@ -863,7 +859,7 @@ ${compile_guard_end}
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
def emit(self, operation):
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
@@ -874,18 +870,12 @@ ${compile_guard_end}
opcode_class_main = operation.tile_description.math_instruction.opcode_class
opcode_class_epi = opcode_class_main
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
opcode_class_epi = OpcodeClass.TensorOp
tile_shape = operation.tile_description.tile_shape
instruction_shape = operation.tile_description.math_instruction.instruction_shape
cluster_m = operation.tile_description.cluster_shape[0]
cluster_n = operation.tile_description.cluster_shape[1]
tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape
tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
# account for static/dynamic cluster shapes
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
@@ -902,10 +892,8 @@ ${compile_guard_end}
if opcode_class_main in [OpcodeClass.TensorOp
, OpcodeClass.BlockScaledTensorOp
]:
tile_shape_main_m = instruction_shape[0]
tile_shape_main_n = instruction_shape[1]
tile_shape_epi_m = cta_m
tile_shape_epi_n = cta_n
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
# stage count set to zero indicates builder automatic stage selection
@@ -930,35 +918,36 @@ ${compile_guard_end}
}
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
#
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
grouped = is_grouped(operation.gemm_kind)
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
operation_name_str = operation.procedural_name()
layout_a_str = LayoutTag[instance_layout_A]
@@ -1041,12 +1030,9 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
'opcode_class_main': OpcodeClassTag[opcode_class_main],
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'tile_shape_epi_m': str(tile_shape_epi_m),
'tile_shape_epi_n': str(tile_shape_epi_n),
'tile_shape_epi_k': str(tile_shape_epi_k),
'tile_shape_main_m': str(tile_shape_main_m),
'tile_shape_main_n': str(tile_shape_main_n),
'tile_shape_main_k': str(tile_shape_main_k),
'tile_shape_m': str(tile_shape_m),
'tile_shape_n': str(tile_shape_n),
'tile_shape_k': str(tile_shape_k),
'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
@@ -1396,7 +1382,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
GemmKind.Grouped: EmitGemmGroupedInstance,
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
}
self.gemm_kind_wrappers = {
@@ -1409,7 +1396,8 @@ class EmitGemmConfigurationLibrary:
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
GemmKind.Grouped: 'GemmGroupedOperation',
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
}
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"