mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-04 21:51:18 +00:00
update 3.8 v2 (#2112)
* update 3.8 v2 * update 3.8 --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@@ -64,17 +64,15 @@ class GemmOperation:
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False
|
||||
|
||||
, ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None
|
||||
|
||||
):
|
||||
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
|
||||
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
GemmKind.SparseUniversal3x,
|
||||
GemmKind.BlockScaledUniversal3x,
|
||||
GemmKind.GroupedGemmUniversal3x,
|
||||
GemmKind.GroupedUniversal3x,
|
||||
GemmKind.GroupedBlockScaledUniversal3x,
|
||||
}
|
||||
self.is_3x = gemm_kind in kinds_3x
|
||||
self.prefix = "3x" if self.is_3x else ""
|
||||
@@ -87,13 +85,11 @@ class GemmOperation:
|
||||
self.C = C
|
||||
self.D = D
|
||||
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
if is_block_scaled(gemm_kind):
|
||||
self.ScaleFactorA = ScaleFactorA
|
||||
self.ScaleFactorB = ScaleFactorB
|
||||
self.ScaleFactorD = ScaleFactorD["tensor"]
|
||||
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
|
||||
|
||||
|
||||
if self.D == None:
|
||||
self.D = self.C
|
||||
@@ -239,13 +235,13 @@ class GemmOperation:
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = DataTypeNames[self.D.element],
|
||||
core_name = self.core_name())
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
|
||||
if is_block_scaled(self.gemm_kind):
|
||||
d_type_names = DataTypeNames[self.D.element]
|
||||
|
||||
|
||||
if self.ScaleFactorD.element != DataType.void:
|
||||
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
|
||||
|
||||
|
||||
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_sfa = DataTypeNames[self.ScaleFactorA],
|
||||
element_a = DataTypeNames[self.A.element],
|
||||
@@ -255,7 +251,7 @@ class GemmOperation:
|
||||
element_c = DataTypeNames[self.C.element],
|
||||
element_d = d_type_names,
|
||||
core_name = self.core_name())
|
||||
|
||||
|
||||
if self.mixed_input_mode != None:
|
||||
extended_name = extended_name + self.mixed_input_mode_name()
|
||||
return extended_name
|
||||
@@ -298,8 +294,8 @@ class GemmOperation:
|
||||
|
||||
# Generates a short string representing underlying epilogue schedule type
|
||||
def epilogue_schedule_name_3x(self):
|
||||
|
||||
if self.gemm_kind == GemmKind.BlockScaledUniversal3x:
|
||||
|
||||
if is_block_scaled(self.gemm_kind):
|
||||
if self.ScaleFactorD.element != DataType.void:
|
||||
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
|
||||
|
||||
@@ -779,7 +775,7 @@ class EmitGemmUniversal3xInstance:
|
||||
using ${operation_name}_epilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
${arch}, ${opcode_class_epi},
|
||||
cute::Shape<cute::_${tile_shape_epi_m}, cute::_${tile_shape_epi_n}, cute::_${tile_shape_epi_k}>,
|
||||
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${epi_tile_mn},
|
||||
${element_accumulator}, ${element_epilogue},
|
||||
@@ -797,7 +793,7 @@ using ${operation_name}_mainloop =
|
||||
${element_a}, ${layout_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${align_b},
|
||||
${element_accumulator},
|
||||
cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
|
||||
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
||||
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
@@ -855,7 +851,7 @@ ${compile_guard_end}
|
||||
|
||||
@staticmethod
|
||||
def pointerize_if_grouped(operation, layout):
|
||||
return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* "
|
||||
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
|
||||
|
||||
@staticmethod
|
||||
def problem_shape(operation):
|
||||
@@ -863,7 +859,7 @@ ${compile_guard_end}
|
||||
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
|
||||
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
|
||||
|
||||
return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type
|
||||
return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
|
||||
|
||||
def emit(self, operation):
|
||||
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
||||
@@ -874,18 +870,12 @@ ${compile_guard_end}
|
||||
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
||||
opcode_class_epi = opcode_class_main
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
|
||||
opcode_class_epi = OpcodeClass.TensorOp
|
||||
|
||||
|
||||
tile_shape = operation.tile_description.tile_shape
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape
|
||||
tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
||||
|
||||
# account for static/dynamic cluster shapes
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
@@ -902,10 +892,8 @@ ${compile_guard_end}
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
]:
|
||||
tile_shape_main_m = instruction_shape[0]
|
||||
tile_shape_main_n = instruction_shape[1]
|
||||
tile_shape_epi_m = cta_m
|
||||
tile_shape_epi_n = cta_n
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
@@ -930,35 +918,36 @@ ${compile_guard_end}
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
|
||||
|
||||
else:
|
||||
epilogue_functor = self.epilogue_functor.emit_declaration()
|
||||
|
||||
if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void:
|
||||
|
||||
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
||||
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
||||
|
||||
|
||||
#
|
||||
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
||||
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
||||
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
|
||||
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100:
|
||||
is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
grouped = is_grouped(operation.gemm_kind)
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
|
||||
operation_name_str = operation.procedural_name()
|
||||
layout_a_str = LayoutTag[instance_layout_A]
|
||||
@@ -1041,12 +1030,9 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape(
|
||||
'opcode_class_main': OpcodeClassTag[opcode_class_main],
|
||||
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
|
||||
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
||||
'tile_shape_epi_m': str(tile_shape_epi_m),
|
||||
'tile_shape_epi_n': str(tile_shape_epi_n),
|
||||
'tile_shape_epi_k': str(tile_shape_epi_k),
|
||||
'tile_shape_main_m': str(tile_shape_main_m),
|
||||
'tile_shape_main_n': str(tile_shape_main_n),
|
||||
'tile_shape_main_k': str(tile_shape_main_k),
|
||||
'tile_shape_m': str(tile_shape_m),
|
||||
'tile_shape_n': str(tile_shape_n),
|
||||
'tile_shape_k': str(tile_shape_k),
|
||||
'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
|
||||
'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
|
||||
'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
|
||||
@@ -1396,7 +1382,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
||||
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
||||
GemmKind.Grouped: EmitGemmGroupedInstance,
|
||||
GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
|
||||
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
||||
}
|
||||
|
||||
self.gemm_kind_wrappers = {
|
||||
@@ -1409,7 +1396,8 @@ class EmitGemmConfigurationLibrary:
|
||||
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
||||
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
||||
GemmKind.Grouped: 'GemmGroupedOperation',
|
||||
GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation'
|
||||
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
|
||||
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
|
||||
}
|
||||
|
||||
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
||||
|
||||
Reference in New Issue
Block a user