CUTLASS 3.4.0 (#1286)

* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
Pradeep Ramani
2023-12-29 12:21:31 -08:00
committed by GitHub
parent b7508e3379
commit 8236f30675
211 changed files with 11409 additions and 2763 deletions

View File

@@ -61,7 +61,8 @@ class GemmOperation:
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default, extra_args = None):
tile_scheduler = TileSchedulerType.Default
):
kinds_3x = {
GemmKind.Universal3x,
@@ -88,6 +89,10 @@ class GemmOperation:
self.epilogue_schedule = epilogue_schedule
self.element_epilogue = element_epilogue
self.epilogue_functor = epilogue_functor
if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
self.epilogue_functor = EpilogueFunctor3x.LinearCombination
self.swizzling_functor = swizzling_functor
self.tile_scheduler = tile_scheduler
@@ -709,9 +714,9 @@ class EmitGemmUniversal3xInstance:
]
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_d},
${element_epilogue},
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
@@ -726,7 +731,8 @@ using ${operation_name}_epilogue =
${element_accumulator}, ${element_epilogue},
${element_c}, ${layout_c}, ${align_c},
${element_d}, ${layout_d}, ${align_d},
${epilogue_schedule}
${epilogue_schedule},
${epilogue_functor}
>::CollectiveOp;
using ${operation_name}_mainloop =
@@ -757,9 +763,11 @@ struct ${operation_name} :
def instance_template(self):
return """
${compile_guard_start}
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
manifest.append(
new ${gemm_kind}<GemmKernel>("${operation_name}"));
{
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
manifest.append(
new ${gemm_kind}<GemmKernel>("${operation_name}"));
}
${compile_guard_end}
"""
@@ -788,9 +796,8 @@ ${compile_guard_end}
# Support built-in epilogue functors or user-defined functions
if isinstance(operation.epilogue_functor, enum.Enum):
values = {
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
}
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
else:
@@ -799,6 +806,9 @@ ${compile_guard_end}
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
element_a = DataTypeTag[operation.A.element]
element_b = DataTypeTag[operation.B.element]
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,

View File

@@ -192,14 +192,14 @@ def CreateGemmUniversal3xOperator(
C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1])
D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])
extra_args = {}
gemm_op_extra_args = {}
gemm_kind = GemmKind.Universal3x
element_compute = data_type.get("epi_type", data_type["acc_type"])
operation = GemmOperation(
gemm_kind, tile_description.minimum_compute_capability,
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
kernel_schedule, epilogue_schedule, tile_scheduler, extra_args)
kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args)
manifest.append(operation)
operations.append(operation)

View File

@@ -466,6 +466,13 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
LinearCombination = enum_auto()
#
EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
}
class TileSchedulerType(enum.Enum):
Default = enum_auto()
Persistent = enum_auto()

View File

@@ -429,7 +429,7 @@ class Manifest:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format(
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))