mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@@ -61,7 +61,8 @@ class GemmOperation:
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default, extra_args = None):
|
||||
tile_scheduler = TileSchedulerType.Default
|
||||
):
|
||||
|
||||
kinds_3x = {
|
||||
GemmKind.Universal3x,
|
||||
@@ -88,6 +89,10 @@ class GemmOperation:
|
||||
self.epilogue_schedule = epilogue_schedule
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_functor = epilogue_functor
|
||||
|
||||
if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
|
||||
self.epilogue_functor = EpilogueFunctor3x.LinearCombination
|
||||
|
||||
self.swizzling_functor = swizzling_functor
|
||||
self.tile_scheduler = tile_scheduler
|
||||
|
||||
@@ -709,9 +714,9 @@ class EmitGemmUniversal3xInstance:
|
||||
]
|
||||
self.builtin_epilogue_functor_template = """
|
||||
${epilogue_functor}<
|
||||
${element_d},
|
||||
${element_epilogue},
|
||||
${element_c},
|
||||
${epilogue_vector_length},
|
||||
${element_accumulator},
|
||||
${element_epilogue}
|
||||
>
|
||||
"""
|
||||
@@ -726,7 +731,8 @@ using ${operation_name}_epilogue =
|
||||
${element_accumulator}, ${element_epilogue},
|
||||
${element_c}, ${layout_c}, ${align_c},
|
||||
${element_d}, ${layout_d}, ${align_d},
|
||||
${epilogue_schedule}
|
||||
${epilogue_schedule},
|
||||
${epilogue_functor}
|
||||
>::CollectiveOp;
|
||||
|
||||
using ${operation_name}_mainloop =
|
||||
@@ -757,9 +763,11 @@ struct ${operation_name} :
|
||||
def instance_template(self):
|
||||
return """
|
||||
${compile_guard_start}
|
||||
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
||||
manifest.append(
|
||||
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
||||
{
|
||||
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
||||
manifest.append(
|
||||
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
||||
}
|
||||
${compile_guard_end}
|
||||
"""
|
||||
|
||||
@@ -788,9 +796,8 @@ ${compile_guard_end}
|
||||
# Support built-in epilogue functors or user-defined functions
|
||||
if isinstance(operation.epilogue_functor, enum.Enum):
|
||||
values = {
|
||||
'epilogue_vector_length': str(epilogue_vector_length),
|
||||
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
||||
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
||||
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
|
||||
}
|
||||
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
||||
else:
|
||||
@@ -799,6 +806,9 @@ ${compile_guard_end}
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
element_a = DataTypeTag[operation.A.element]
|
||||
element_b = DataTypeTag[operation.B.element]
|
||||
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'operation_suffix': self.operation_suffix,
|
||||
|
||||
@@ -192,14 +192,14 @@ def CreateGemmUniversal3xOperator(
|
||||
C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1])
|
||||
D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])
|
||||
|
||||
extra_args = {}
|
||||
gemm_op_extra_args = {}
|
||||
gemm_kind = GemmKind.Universal3x
|
||||
element_compute = data_type.get("epi_type", data_type["acc_type"])
|
||||
|
||||
operation = GemmOperation(
|
||||
gemm_kind, tile_description.minimum_compute_capability,
|
||||
tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
|
||||
kernel_schedule, epilogue_schedule, tile_scheduler, extra_args)
|
||||
kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args)
|
||||
|
||||
manifest.append(operation)
|
||||
operations.append(operation)
|
||||
|
||||
@@ -466,6 +466,13 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
}
|
||||
|
||||
class EpilogueFunctor3x(enum.Enum):
|
||||
LinearCombination = enum_auto()
|
||||
#
|
||||
EpilogueFunctor3xTag = {
|
||||
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
|
||||
}
|
||||
|
||||
class TileSchedulerType(enum.Enum):
|
||||
Default = enum_auto()
|
||||
Persistent = enum_auto()
|
||||
|
||||
@@ -429,7 +429,7 @@ class Manifest:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user