mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release. * Minor fix in gemm op profiler for raster order. * Add scheduler mapping for raster order in the kernels.
This commit is contained in:
49
python/cutlass_library/__init__.py
Normal file
49
python/cutlass_library/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import sys
|
||||
|
||||
from . import conv2d_operation
|
||||
from . import conv3d_operation
|
||||
from . import gemm_operation
|
||||
|
||||
if '-m' not in sys.argv:
|
||||
# Do not import generator when running python -m cutlass_library.generator to
|
||||
# avoid double-import warnings
|
||||
from . import generator
|
||||
|
||||
from . import library
|
||||
from . import manifest
|
||||
from . import rank_2k_operation
|
||||
from . import rank_k_operation
|
||||
from . import symm_operation
|
||||
from . import trmm_operation
|
||||
492
python/cutlass_library/conv2d_operation.py
Normal file
492
python/cutlass_library/conv2d_operation.py
Normal file
@@ -0,0 +1,492 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting Conv2d kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class Conv2dOperation:
    """Describes one Conv2d kernel configuration and derives its names.

    The instance stores the convolution kind, the operand descriptions
    (A, B, C), the tile description and the remaining kernel options. The
    naming helpers combine these into the procedural name of the kernel.
    """

    def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue,
                 stride_support, epilogue_functor = EpilogueFunctor.LinearCombination,
                 swizzling_functor = SwizzlingFunctor.Identity1, group_mode = GroupMode.NoneGroup):
        """Store every argument on the instance; the operation kind is always Conv2d."""
        self.operation_kind = OperationKind.Conv2d

        # Problem kind and target.
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description

        # Operand descriptions.
        self.A = A
        self.B = B
        self.C = C

        # Epilogue configuration.
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor

        # Iteration, striding, swizzling and grouping options.
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        self.swizzling_functor = swizzling_functor
        self.group_mode = group_mode

    def is_complex(self):
        """Return True when the tile's math operation is one of the complex multiply-add variants."""
        math_operation = self.tile_description.math_instruction.math_operation
        return math_operation in (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
        )

    def accumulator_type(self):
        """Return the accumulator element type, mapped through get_complex_from_real for complex operations."""
        real_accumulator = self.tile_description.math_instruction.element_accumulator
        return get_complex_from_real(real_accumulator) if self.is_complex() else real_accumulator

    def core_name(self):
        """Return the core name: accumulator short name, instruction shape, intermediate type, conv kind and iterator algorithm."""
        math_inst = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        if math_inst.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)

            # The instruction's A element is only spelled out when it matches
            # neither the A operand's element nor the accumulator type.
            differs_from_operand = math_inst.element_a != self.A.element
            if differs_from_operand and math_inst.element_a != self.accumulator_type():
                intermediate_type = DataTypeNames[math_inst.element_a]

        accumulator_tag = ShortDataTypeNames[self.accumulator_type()]
        conv_kind_tag = ConvKindNames[self.conv_kind]
        algorithm_tag = IteratorAlgorithmNames[self.iterator_algorithm]

        return f"{accumulator_tag}{inst_shape}{intermediate_type}{conv_kind_tag}_{algorithm_tag}"

    def extended_name(self):
        """Return the core name, decorated with the C and/or A element names when they differ from the accumulator element."""
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.A.element == accumulator:
            # A matches the accumulator: no decoration, whatever C is.
            pattern = "${core_name}"
        elif self.C.element != accumulator:
            pattern = "${element_c}_${core_name}_${element_a}"
        else:
            pattern = "${core_name}_${element_a}"

        substitutions = {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        }
        return SubstituteTemplate(pattern, substitutions)

    def layout_name(self):
        """Return the short layout name of operand A."""
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    def configuration_name(self):
        """Return the full kernel name: opcode class, extended name, threadblock name, layout, optional unity-stride and group tags, and A's alignment."""
        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
        threadblock = self.tile_description.procedural_name()

        # Grouped convolutions carry the group mode name followed by an underscore.
        group_conv_name = ""
        if self.group_mode != GroupMode.NoneGroup:
            group_conv_name = f"{GroupModeNames[self.group_mode]}_"

        # Unity-stride kernels insert an extra tag between the layout and the group name.
        stride_tag = "_unity_stride" if self.stride_support == StrideSupport.Unity else ""
        pattern = (
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"
            + stride_tag
            + "_${group_conv_name}align${alignment}"
        )

        substitutions = {
            'opcode_class': opcode_class_name,
            'extended_name': self.extended_name(),
            'threadblock': threadblock,
            'layout': self.layout_name(),
            'alignment': "%d" % self.A.alignment,
            'group_conv_name': group_conv_name
        }
        return SubstituteTemplate(pattern, substitutions)

    def procedural_name(self):
        """Return the procedural name, which is identical to configuration_name()."""
        return self.configuration_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitConv2dInstance:
    """Emits the C++ type alias for a single Conv2d kernel instance.

    Three templates are kept: the default kernel, the grouped kernel and the
    depthwise direct-convolution kernel. emit() picks one from the
    operation's group mode and fills in its placeholders.
    """

    def __init__(self):
        """Define the three C++ source templates."""
        # Default (non-grouped) kernel.
        self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
"""
        # Grouped kernel: same parameters plus the group mode.
        self.template_group_conv = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
${group_mode},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
"""
        # Depthwise direct convolution: adds output tile, filter, stride and dilation shapes.
        self.template_depthwise_direct_conv = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name}_base =
typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>,

cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
${threadblock_output_shape_n},
${threadblock_output_shape_p},
${threadblock_output_shape_q}>,
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
cutlass::MatrixShape<${stride_r}, ${stride_s}>,
cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
>::Kernel;
"""

    def emit(self, operation):
        """Return the C++ source declaring `<procedural_name>_base` for `operation`.

        The template is chosen from operation.group_mode: the default template
        for NoneGroup, the depthwise direct-conv template for Depthwise, and
        the grouped template for every other mode.
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction

        # Warp tile = threadblock tile divided by the warp count along each of the three axes.
        warp_shape = [int(tile.threadblock_shape[axis] / tile.warp_count[axis]) for axis in range(3)]

        # Elements of C per epilogue access; the access width is capped at 128 (in DataTypeSize units).
        c_element_size = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_element_size, 128) / c_element_size)

        if operation.is_complex():
            math_operator = 'cutlass::arch::OpMultiplyAddComplex'
        else:
            math_operator = MathOperationTag[math_inst.math_operation]

        values = {
            'operation_name': operation.procedural_name(),
            'conv_kind': ConvKindTag[operation.conv_kind],
            'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
        }

        # threadblock_shape_{m,n,k}, warp_shape_{m,n,k}, instruction_shape_{m,n,k}
        gemm_shapes = (
            ('threadblock_shape', tile.threadblock_shape),
            ('warp_shape', warp_shape),
            ('instruction_shape', math_inst.instruction_shape),
        )
        for prefix, shape in gemm_shapes:
            for axis, suffix in enumerate('mnk'):
                values[prefix + '_' + suffix] = str(shape[axis])

        values.update({
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
            'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            'stride_support': StrideSupportTag[operation.stride_support],
            'math_operator': math_operator,
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
        })

        if operation.group_mode == GroupMode.NoneGroup:
            return SubstituteTemplate(self.template, values)

        # Every grouped flavour records the group mode tag.
        values['group_mode'] = GroupModeTag[operation.group_mode]

        if operation.group_mode != GroupMode.Depthwise:
            return SubstituteTemplate(self.template_group_conv, values)

        # Depthwise direct conv: extra shapes read from the tile description.
        output_shape = tile.threadblock_output_shape
        values.update({
            'threadblock_output_shape_n': str(output_shape[0]),
            'threadblock_output_shape_p': str(output_shape[1]),
            'threadblock_output_shape_q': str(output_shape[2]),
            'groups_per_cta': str(output_shape[3]),
            'filter_shape_r': str(tile.filter_shape[0]),
            'filter_shape_s': str(tile.filter_shape[1]),
            'stride_r': str(tile.stride[0]),
            'stride_s': str(tile.stride[1]),
            'dilation_r': str(tile.dilation[0]),
            'dilation_s': str(tile.dilation[1]),
        })
        return SubstituteTemplate(self.template_depthwise_direct_conv, values)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Generator functions for all layouts
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128, iterator_algorithm = IteratorAlgorithm.Optimized):
    """Append Conv2d operations for every tile description to `manifest`.

    For each tile, Fprop is always generated; Dgrad and Wgrad are generated
    only when the accumulator is f16 or f32. When the accumulator's
    DataTypeSize is 32, one operation is generated with the instruction's A
    element as output type and one with the accumulator type; otherwise only
    the accumulator type is used. All tensors use the TensorNHWC layout.

    Parameters:
      manifest            - object with an append() method receiving each Conv2dOperation
      tile_descriptions   - iterable of tile descriptions exposing `math_instruction`
      min_cc              - passed to Conv2dOperation as `arch`
      align               - access width used to derive per-tensor alignments
                            (align / DataTypeSize[element])
      iterator_algorithm  - iterator algorithm passed to every Conv2dOperation

    The previous implementation omitted the `iterator_algorithm` and
    `stride_support` arguments of Conv2dOperation, so every call raised
    TypeError; both are now passed explicitly.
    """
    for tile in tile_descriptions:
        math_inst = tile.math_instruction
        accumulator = math_inst.element_accumulator

        for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
            # Dgrad / Wgrad are restricted to f16 and f32 accumulation.
            if conv_kind != ConvKind.Fprop and accumulator not in [DataType.f16, DataType.f32]:
                continue

            if DataTypeSize[accumulator] == 32:
                output_types = [math_inst.element_a, accumulator]
            else:
                output_types = [accumulator]

            # NOTE(review): the stride support per conv kind is a choice made here, not
            # something the original code specified -- Dgrad uses Unity because the
            # operation keeps its default swizzling functor; confirm against the kernels
            # you intend to build.
            if conv_kind == ConvKind.Dgrad:
                stride_support = StrideSupport.Unity
            else:
                stride_support = StrideSupport.Strided

            for output_type in output_types:
                A = TensorDescription(math_inst.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[math_inst.element_a]))
                B = TensorDescription(math_inst.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[math_inst.element_b]))
                # C's alignment is clamped to at least one element.
                C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))

                manifest.append(Conv2dOperation(
                    conv_kind, iterator_algorithm, min_cc, tile, A, B, C, accumulator, stride_support))
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitter functions for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitConv2dConfigurationLibrary:
    """Context manager that writes one `<configuration_name>.cu` file.

    On entry the output file is opened and the header is written. Each call
    to emit() writes one kernel instance and records the operation. On exit
    the `initialize_<configuration_name>` function registering every recorded
    operation is written, followed by the closing text, and the file is
    closed.
    """

    def __init__(self, operation_path, configuration_name):
        """Compute the output path and define the source templates.

        Parameters:
          operation_path      - directory in which the .cu file is created
          configuration_name  - base name of the file and suffix of the initialize function
        """
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

        self.instance_emitter = EmitConv2dInstance()

        # One kernel instance followed by a struct deriving from its `_base` alias.
        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        # Top of the generated file.
        self.header_template = """
/*
Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        # Opens the namespaces and the initialize function.
        self.configuration_header = """

namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {

"""

        # Registration of an implicit-GEMM convolution.
        self.configuration_instance = """
using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
${operation_name}>;

manifest.append(new cutlass::library::Conv2dOperation<
Operation_${operation_name}>(
"${operation_name}"));

"""

        # Registration of a direct convolution (used for depthwise operations).
        self.configuration_direct_conv_instance = """
using Operation_${operation_name} = cutlass::conv::device::DirectConvolution<
${operation_name}>;

manifest.append(new cutlass::library::DirectConv2dOperation<
Operation_${operation_name}>(
"${operation_name}"));

"""

        # Closes the initialize function.
        self.configuration_epilogue = """
}
"""
        # Closes the namespaces.
        self.epilogue_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        """Open the output file, write the header and reset the list of operations."""
        self.configuration_file = open(self.configuration_path, "w")
        try:
            self.configuration_file.write(SubstituteTemplate(self.header_template, {
                'configuration_name': self.configuration_name
            }))
        except Exception:
            # __exit__ is not called when __enter__ raises, so release the handle here.
            self.configuration_file.close()
            raise
        self.operations = []
        return self

    def emit(self, operation):
        """Write the kernel instance for `operation` and remember it for registration on exit."""
        self.operations.append(operation)
        self.configuration_file.write(SubstituteTemplate(self.instance_template, {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'operation_instance': self.instance_emitter.emit(operation)
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the initialize function and closing text, then close the file.

        The file is closed even if one of the writes raises. Nothing is
        returned, so an exception raised inside the `with` body propagates.
        """
        try:
            self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
                'configuration_name': self.configuration_name
            }))

            for operation in self.operations:
                # Depthwise operations register as direct convolutions, all others as implicit GEMM.
                if operation.group_mode == GroupMode.Depthwise:
                    registration_template = self.configuration_direct_conv_instance
                else:
                    registration_template = self.configuration_instance

                self.configuration_file.write(SubstituteTemplate(registration_template, {
                    'configuration_name': self.configuration_name,
                    'operation_name': operation.procedural_name()
                }))

            self.configuration_file.write(self.configuration_epilogue)
            self.configuration_file.write(self.epilogue_template)
        finally:
            self.configuration_file.close()
|
||||
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
350
python/cutlass_library/conv3d_operation.py
Normal file
350
python/cutlass_library/conv3d_operation.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting Conv3d kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class Conv3dOperation:
    """Describes one Conv3d kernel configuration and derives its names."""

    def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue,
                 stride_support, epilogue_functor = EpilogueFunctor.LinearCombination,
                 swizzling_functor = SwizzlingFunctor.Identity4):
        """Store every argument on the instance; the operation kind is always Conv3d."""
        self.operation_kind = OperationKind.Conv3d

        # Problem kind and target.
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description

        # Operand descriptions.
        self.A = A
        self.B = B
        self.C = C

        # Epilogue configuration.
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor

        # Iteration, striding and swizzling options.
        self.iterator_algorithm = iterator_algorithm
        self.stride_support = stride_support
        self.swizzling_functor = swizzling_functor

    def core_name(self):
        """Return the core name: accumulator short name, instruction shape, intermediate type, conv kind with a `3d` suffix, and iterator algorithm."""
        math_inst = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        if math_inst.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)

            # The instruction's A element is only spelled out when it matches
            # neither the A operand's element nor the accumulator element.
            differs_from_operand = math_inst.element_a != self.A.element
            if differs_from_operand and math_inst.element_a != math_inst.element_accumulator:
                intermediate_type = DataTypeNames[math_inst.element_a]

        accumulator_tag = ShortDataTypeNames[math_inst.element_accumulator]
        conv_kind_tag = ConvKindNames[self.conv_kind]
        algorithm_tag = IteratorAlgorithmNames[self.iterator_algorithm]

        return f"{accumulator_tag}{inst_shape}{intermediate_type}{conv_kind_tag}3d_{algorithm_tag}"

    def extended_name(self):
        """Return the core name, decorated with the C and/or A element names when they differ from the accumulator element."""
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.A.element == accumulator:
            # A matches the accumulator: no decoration, whatever C is.
            pattern = "${core_name}"
        elif self.C.element != accumulator:
            pattern = "${element_c}_${core_name}_${element_a}"
        else:
            pattern = "${core_name}_${element_a}"

        substitutions = {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        }
        return SubstituteTemplate(pattern, substitutions)

    def configuration_name(self):
        """Return the full kernel name: opcode class, extended name, threadblock shape with stage count, and an optional unity-stride tag."""
        tile = self.tile_description
        opcode_class_name = OpcodeClassNames[tile.math_instruction.opcode_class]

        # "<M>x<N>_<K>x<stages>"
        threadblock_fields = (
            tile.threadblock_shape[0],
            tile.threadblock_shape[1],
            tile.threadblock_shape[2],
            tile.stages,
        )
        threadblock = "%dx%d_%dx%d" % threadblock_fields

        pattern = "cutlass_${opcode_class}_${extended_name}_${threadblock}"
        if self.stride_support == StrideSupport.Unity:
            pattern += "_unity_stride"

        substitutions = {
            'opcode_class': opcode_class_name,
            'extended_name': self.extended_name(),
            'threadblock': threadblock,
        }
        return SubstituteTemplate(pattern, substitutions)

    def procedural_name(self):
        """Return the procedural name, which is identical to configuration_name()."""
        return self.configuration_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitConv3dInstance:
    """Renders the C++ declaration of a single Conv3d kernel instance."""

    def __init__(self):
        # C++ text with ${...} placeholders; filled in by emit().
        self.template = """
// Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
  using ${operation_name}_base =
    typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}<
      ${element_a},
      cutlass::layout::TensorNDHWC,
      ${element_b},
      cutlass::layout::TensorNDHWC,
      ${element_c},
      cutlass::layout::TensorNDHWC,
      ${element_accumulator},
      ${opcode_class},
      ${arch},
      cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
      cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
      cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
      ${epilogue_functor}<
        ${element_c},
        ${epilogue_vector_length},
        ${element_accumulator},
        ${element_epilogue}
      >,
      ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
      ${stages},
      cutlass::arch::OpMultiplyAdd,
      ${iterator_algorithm},
      ${stride_support}
    >::Kernel;
"""

    def emit(self, operation):
        """Return the instance template with placeholders filled from ``operation``.

        The warp shape is the threadblock shape divided element-wise by the
        warp count. The epilogue vector length is the C alignment, capped so
        that ``alignment * DataTypeSize[C.element]`` does not exceed 128.
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction
        threadblock_shape = tile.threadblock_shape
        instruction_shape = math_inst.instruction_shape

        warp_shape = [
            int(threadblock_shape[idx] / tile.warp_count[idx])
            for idx in range(3)
        ]

        c_size = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_size, 128) / c_size)

        # Some keys (conv_kind, layout_*) have no placeholder in the template;
        # they are kept so the substitution table matches the original one.
        values = {
            'operation_name': operation.procedural_name(),
            'conv_kind': ConvKindTag[operation.conv_kind],
            'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[math_inst.element_accumulator],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(threadblock_shape[0]),
            'threadblock_shape_n': str(threadblock_shape[1]),
            'threadblock_shape_k': str(threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(instruction_shape[0]),
            'instruction_shape_n': str(instruction_shape[1]),
            'instruction_shape_k': str(instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
            'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
            'stride_support': StrideSupportTag[operation.stride_support]
        }

        return SubstituteTemplate(self.template, values)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Generator functions for all layouts
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
    """Append Conv3d operations for every tile description to ``manifest``.

    For each tile, Fprop is always generated; Dgrad and Wgrad are generated
    only when the accumulator type is f16 or f32. When the accumulator is
    32 bits wide, one operation is emitted with the A element type as output
    and one with the accumulator type as output; otherwise only the
    accumulator type is used. All tensors use the NDHWC layout, and each
    alignment is ``align`` divided by the element size (at least 1 for C).
    """
    for tile in tile_descriptions:
        for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
            math_inst = tile.math_instruction
            accumulator = math_inst.element_accumulator

            # Non-Fprop kinds are restricted to f16/f32 accumulation.
            if conv_kind != ConvKind.Fprop and accumulator not in [DataType.f16, DataType.f32]:
                continue

            if DataTypeSize[accumulator] == 32:
                output_types = [math_inst.element_a, accumulator]
            else:
                output_types = [accumulator,]

            for output_type in output_types:
                A = TensorDescription(
                    math_inst.element_a, LayoutType.TensorNDHWC,
                    int(align / DataTypeSize[math_inst.element_a]))
                B = TensorDescription(
                    math_inst.element_b, LayoutType.TensorNDHWC,
                    int(align / DataTypeSize[math_inst.element_b]))
                C = TensorDescription(
                    output_type, LayoutType.TensorNDHWC,
                    max(1, int(align / DataTypeSize[output_type])))

                manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, accumulator))
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitters for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitConv3dConfigurationLibrary:
    """Context manager that writes one ``<configuration_name>.cu`` file.

    On entry the output file is opened and the header is written. Each call
    to ``emit`` writes one kernel instance and records the operation. On exit
    the ``initialize_<configuration_name>`` function, which registers every
    recorded operation with the manifest, and the closing namespaces are
    written, and the file is closed.
    """

    def __init__(self, operation_path, configuration_name):
        """Store the output path and set up the C++ text templates.

        Args:
            operation_path: directory in which the ``.cu`` file is created.
            configuration_name: base name of the file and suffix of the
                generated ``initialize_`` function.
        """
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

        self.instance_emitter = EmitConv3dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        self.header_template = """
/*
  Generated by conv3d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv3d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """

namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {

"""

        self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv3dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));

"""

        self.configuration_epilogue = """
}
"""
        self.epilogue_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    #
    def __enter__(self):
        """Open the output file, write the header, and return ``self``."""
        self.configuration_file = open(self.configuration_path, "w")
        try:
            self.configuration_file.write(SubstituteTemplate(self.header_template, {
                'configuration_name': self.configuration_name
            }))
        except Exception:
            # __exit__ is not called when __enter__ raises, so close here.
            self.configuration_file.close()
            raise
        self.operations = []
        return self

    #
    def emit(self, operation):
        """Write the instance declaration for ``operation`` and record it."""
        self.operations.append(operation)
        self.configuration_file.write(SubstituteTemplate(self.instance_template, {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'operation_instance': self.instance_emitter.emit(operation)
        }))

    #
    def __exit__(self, exception_type, exception_value, traceback):
        """Write the initializer function and epilogue, then close the file.

        Returns None, so an exception raised inside the ``with`` body is
        propagated. The file is closed even if one of the writes fails.
        """
        try:
            self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
                'configuration_name': self.configuration_name
            }))

            for operation in self.operations:
                self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
                    'configuration_name': self.configuration_name,
                    'operation_name': operation.procedural_name()
                }))

            self.configuration_file.write(self.configuration_epilogue)
            self.configuration_file.write(self.epilogue_template)
        finally:
            self.configuration_file.close()
|
||||
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
1237
python/cutlass_library/gemm_operation.py
Normal file
1237
python/cutlass_library/gemm_operation.py
Normal file
File diff suppressed because it is too large
Load Diff
5382
python/cutlass_library/generator.py
Normal file
5382
python/cutlass_library/generator.py
Normal file
File diff suppressed because it is too large
Load Diff
990
python/cutlass_library/library.py
Normal file
990
python/cutlass_library/library.py
Normal file
@@ -0,0 +1,990 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Data types and tags used for emitting CUTLASS C++ kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import re
|
||||
|
||||
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
|
||||
# as the default 3.5.2 on Ubuntu 16.04.
|
||||
#
|
||||
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
|
||||
|
||||
try:
    from enum import auto as enum_auto
except ImportError:
    # Fallback for interpreters whose enum module has no auto(): a
    # module-level counter hands out consecutive integers starting at 0.
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """Return the next integer from the module-level counter."""
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class GeneratorTarget(enum.Enum):
    """Targets the generator can emit for; only ``Library`` is defined."""
    Library = enum_auto()

# Lower-case name of each generator target.
GeneratorTargetNames = {
    GeneratorTarget.Library: 'library'
}
|
||||
#
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class DataType(enum.Enum):
    """Element data types known to the generator.

    Members prefixed with ``c`` are the complex counterparts of the
    corresponding real types (see DataTypeTag). ``invalid`` is the value
    returned by the real/complex conversion helpers when no match exists.
    """
    void = enum_auto()  # primarily used to disable C tensor for epilogues
    b1 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    e4m3 = enum_auto()
    e5m2 = enum_auto()
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    invalid = enum_auto()
|
||||
|
||||
#
|
||||
# Abbreviated names for a subset of data types; types not listed here have
# no short name in this table.
ShortDataTypeNames = {
    DataType.s32: 'i',
    DataType.e4m3: 'e4m3',
    DataType.e5m2: 'e5m2',
    DataType.f16: 'h',
    DataType.f32: 's',
    DataType.f64: 'd',
    DataType.cf32: 'c',
    DataType.cf64: 'z',
}
|
||||
|
||||
#
|
||||
# Textual name of each data type; every DataType except ``invalid`` has an
# entry, and each name equals the enum member's identifier.
DataTypeNames = {
    DataType.void: "void",
    DataType.b1: "b1",
    DataType.u4: "u4",
    DataType.u8: "u8",
    DataType.u16: "u16",
    DataType.u32: "u32",
    DataType.u64: "u64",
    DataType.s4: "s4",
    DataType.s8: "s8",
    DataType.s16: "s16",
    DataType.s32: "s32",
    DataType.s64: "s64",
    DataType.e4m3: 'e4m3',
    DataType.e5m2: 'e5m2',
    DataType.f16: "f16",
    DataType.bf16: "bf16",
    DataType.f32: "f32",
    DataType.tf32: "tf32",
    DataType.f64: "f64",
    DataType.cf16: "cf16",
    DataType.cbf16: "cbf16",
    DataType.cf32: "cf32",
    DataType.ctf32: "ctf32",
    DataType.cf64: "cf64",
    DataType.cu4: "cu4",
    DataType.cu8: "cu8",
    DataType.cu16: "cu16",
    DataType.cu32: "cu32",
    DataType.cu64: "cu64",
    DataType.cs4: "cs4",
    DataType.cs8: "cs8",
    DataType.cs16: "cs16",
    DataType.cs32: "cs32",
    DataType.cs64: "cs64",
}
|
||||
|
||||
# C++ type spelling substituted into emitted kernel source for each data type.
DataTypeTag = {
    DataType.void: "void",
    DataType.b1: "cutlass::uint1b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t",
    DataType.u16: "uint16_t",
    DataType.u32: "uint32_t",
    DataType.u64: "uint64_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t",
    DataType.s16: "int16_t",
    DataType.s32: "int32_t",
    DataType.s64: "int64_t",
    DataType.e4m3: 'cutlass::float_e4m3_t',
    DataType.e5m2: 'cutlass::float_e5m2_t',
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    # NOTE(review): the complex integer tags below spell the element type as
    # cutlass::uint8_t, cutlass::int8_t, etc., whereas the real integer tags
    # above use the unqualified uint8_t / int8_t - confirm these resolve.
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
    DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
    DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
    DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<cutlass::int8_t>",
    DataType.cs16: "cutlass::complex<cutlass::int16_t>",
    DataType.cs32: "cutlass::complex<cutlass::int32_t>",
    DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}
|
||||
|
||||
# Size of each data type. The values match widths in bits (b1 -> 1, u8 -> 8,
# f64 -> 64); complex entries are twice the size of their real counterpart.
DataTypeSize = {
    DataType.void: 0,
    DataType.b1: 1,
    DataType.u4: 4,
    DataType.u8: 8,
    DataType.u16: 16,
    DataType.u32: 32,
    DataType.u64: 64,
    DataType.s4: 4,
    DataType.s8: 8,
    DataType.s16: 16,
    DataType.s32: 32,
    DataType.s64: 64,
    DataType.e4m3: 8,
    DataType.e5m2: 8,
    DataType.f16: 16,
    DataType.bf16: 16,
    DataType.f32: 32,
    DataType.tf32: 32,
    DataType.f64: 64,
    DataType.cf16: 32,
    DataType.cbf16: 32,
    DataType.cf32: 64,
    # NOTE(review): tf32 is 32 above, so 64 would be expected here by the
    # same doubling rule as the other complex types - confirm whether 32 is
    # intentional.
    DataType.ctf32: 32,
    DataType.cf64: 128,
    DataType.cu4: 8,
    DataType.cu8: 16,
    DataType.cu16: 32,
    DataType.cu32: 64,
    DataType.cu64: 128,
    DataType.cs4: 8,
    DataType.cs8: 16,
    DataType.cs16: 32,
    DataType.cs32: 64,
    DataType.cs64: 128,
}
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
class BlasMode(enum.Enum):
    """BLAS matrix mode: symmetric or hermitian."""
    symmetric = enum_auto()
    hermitian = enum_auto()

# C++ enumerator emitted for each BLAS mode.
BlasModeTag = {
    BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
    BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}
|
||||
|
||||
#
|
||||
class ComplexTransform(enum.Enum):
    """Transform applied to a complex operand: none or conjugate."""
    none = enum_auto()
    conj = enum_auto()

# C++ enumerator emitted for each complex transform.
ComplexTransformTag = {
    ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
    ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
|
||||
|
||||
#
|
||||
# (real, complex) type pairs consulted by is_complex(),
# get_complex_from_real() and get_real_from_complex(). Only the f16, f32 and
# f64 pairs are listed, so the other complex DataType members (cbf16, ctf32,
# cs*, cu*) are not recognised by those helpers.
RealComplexBijection = [
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]
|
||||
|
||||
#
|
||||
def is_complex(data_type):
    """Return True if ``data_type`` is the complex member of a pair in RealComplexBijection."""
    return any(data_type == complex_type for _, complex_type in RealComplexBijection)
|
||||
|
||||
#
|
||||
def get_complex_from_real(real_type):
    """Return the complex type paired with ``real_type`` in RealComplexBijection.

    Returns ``DataType.invalid`` when ``real_type`` has no listed counterpart.
    """
    candidates = (
        complex_type
        for candidate_real, complex_type in RealComplexBijection
        if real_type == candidate_real
    )
    return next(candidates, DataType.invalid)
|
||||
|
||||
#
|
||||
def get_real_from_complex(complex_type):
    """Return the real type paired with ``complex_type`` in RealComplexBijection.

    Returns ``DataType.invalid`` when ``complex_type`` has no listed counterpart.
    """
    candidates = (
        real_type
        for real_type, candidate_complex in RealComplexBijection
        if complex_type == candidate_complex
    )
    return next(candidates, DataType.invalid)
|
||||
|
||||
#
|
||||
class ComplexMultiplyOp(enum.Enum):
    """Complex multiply variants: ``multiply_add`` and ``gaussian``."""
    multiply_add = enum_auto()
    gaussian = enum_auto()
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class MathOperation(enum.Enum):
    """Math operations a kernel's inner product can be built from."""
    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    xor_popc = enum_auto()
    and_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()

# C++ operator tag type (in cutlass::arch) emitted for each math operation.
MathOperationTag = {
    MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
    MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
    MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
    MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
    MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
    MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
    MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
    MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
    MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
    MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class LayoutType(enum.Enum):
    """Matrix and tensor memory layouts known to the generator."""
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()
|
||||
|
||||
#
|
||||
# C++ layout type (in cutlass::layout) emitted for each layout.
LayoutTag = {
    LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
    LayoutType.RowMajor: 'cutlass::layout::RowMajor',
    LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
    LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
    LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
    LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
    LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
    LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
    LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
    LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
    LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
    LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
    LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
    LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
    LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
    LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
}
|
||||
|
||||
#
|
||||
# Layout obtained by transposing each listed layout: column-major and
# row-major variants swap, and TensorNHWC maps to itself. Tensor layouts other
# than TensorNHWC have no entry.
TransposedLayout = {
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC
}
|
||||
|
||||
#
|
||||
# Abbreviated name of each layout ('n' for column-major, 't' for row-major,
# with the interleave factor appended; lower-case names for tensor layouts).
ShortLayoutTypeNames = {
    LayoutType.ColumnMajor: 'n',
    LayoutType.ColumnMajorInterleaved2: 'n2',
    LayoutType.ColumnMajorInterleaved32: 'n32',
    LayoutType.ColumnMajorInterleaved64: 'n64',
    LayoutType.RowMajor: 't',
    LayoutType.RowMajorInterleaved2: 't2',
    LayoutType.RowMajorInterleaved32: 't32',
    LayoutType.RowMajorInterleaved64: 't64',
    LayoutType.TensorNHWC: 'nhwc',
    LayoutType.TensorNDHWC: 'ndhwc',
    LayoutType.TensorNCHW: 'nchw',
    LayoutType.TensorNGHWC: 'nghwc',
    LayoutType.TensorNC32HW32: 'nc32hw32',
    LayoutType.TensorNC64HW64: 'nc64hw64',
    LayoutType.TensorC32RSK32: 'c32rsk32',
    LayoutType.TensorC64RSK64: 'c64rsk64'
}
|
||||
|
||||
#
|
||||
# Abbreviated name keyed by (layout, complex transform): the conjugated
# column-major and row-major layouts are 'c' and 'h' respectively.
ShortComplexLayoutNames = {
    (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
    (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
    (LayoutType.RowMajor, ComplexTransform.none): 't',
    (LayoutType.RowMajor, ComplexTransform.conj): 'h'
}
|
||||
|
||||
###################################################################################################
|
||||
class KernelScheduleType(enum.Enum):
    """Kernel (mainloop) schedule policies."""
    ScheduleAuto = enum_auto()
    Multistage = enum_auto()
    Tma = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()
    TmaWarpSpecializedFP8FastAccum = enum_auto()
    TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
    TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()

# C++ schedule tag type emitted for each kernel schedule.
KernelScheduleTag = {
    KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
    KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
    KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
    KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
    KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
    KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
}

# Name suffix associated with each kernel schedule (empty for ScheduleAuto).
KernelScheduleSuffixes = {
    KernelScheduleType.ScheduleAuto: '',
    KernelScheduleType.Multistage: '_cpasync',
    KernelScheduleType.Tma: '_unspecialized',
    KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
    KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
    KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
}
|
||||
|
||||
class EpilogueScheduleType(enum.Enum):
    """Epilogue schedule policies."""
    ScheduleAuto = enum_auto()
    EpilogueTransposed = enum_auto()
    NoSmemWarpSpecialized = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()

# C++ schedule tag type emitted for each epilogue schedule.
EpilogueScheduleTag = {
    EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
    EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
    EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
    EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
    EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
}

# Name suffix associated with each epilogue schedule. Both TMA
# warp-specialized schedules share the '_epi_tma' suffix.
EpilogueScheduleSuffixes = {
    EpilogueScheduleType.ScheduleAuto: '',
    EpilogueScheduleType.EpilogueTransposed: '',
    EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
    EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
    EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}
|
||||
|
||||
class TileSchedulerType(enum.Enum):
    """Tile scheduler selection."""
    Default = enum_auto()
    Persistent = enum_auto()
    StreamK = enum_auto()

# C++ scheduler tag emitted for each tile scheduler; Default maps to 'void'.
TileSchedulerTag = {
    TileSchedulerType.Default: 'void',
    TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
    TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
}

# Name suffix associated with each tile scheduler; only StreamK has one.
TileSchedulerSuffixes = {
    TileSchedulerType.Default: '',
    TileSchedulerType.Persistent: '',
    TileSchedulerType.StreamK: '_stream_k',
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class SideMode(enum.Enum):
    """Side selector: Left or Right."""
    Left = enum_auto()
    Right = enum_auto()

# C++ enumerator emitted for each side mode.
SideModeTag = {
    SideMode.Left: 'cutlass::SideMode::kLeft',
    SideMode.Right: 'cutlass::SideMode::kRight'
}

# Abbreviated name of each side mode.
ShortSideModeNames = {
    SideMode.Left: 'ls',
    SideMode.Right: 'rs'
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class FillMode(enum.Enum):
    """Fill selector: Lower or Upper."""
    Lower = enum_auto()
    Upper = enum_auto()

# C++ enumerator emitted for each fill mode.
FillModeTag = {
    FillMode.Lower: 'cutlass::FillMode::kLower',
    FillMode.Upper: 'cutlass::FillMode::kUpper'
}

# Abbreviated name of each fill mode.
ShortFillModeNames = {
    FillMode.Lower: 'l',
    FillMode.Upper: 'u'
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class DiagType(enum.Enum):
    """Diagonal selector: NonUnit or Unit."""
    NonUnit = enum_auto()
    Unit = enum_auto()

# C++ enumerator emitted for each diagonal type.
DiagTypeTag = {
    DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
    DiagType.Unit: 'cutlass::DiagType::kUnit'
}

# Abbreviated name of each diagonal type.
ShortDiagTypeNames = {
    DiagType.NonUnit: 'nu',
    DiagType.Unit: 'un'
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class OpcodeClass(enum.Enum):
    """Classes of math instruction a kernel can be built on."""
    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()
    SparseTensorOp = enum_auto()


# Lower-case name of each opcode class. Every OpcodeClass member has an
# entry, so a lookup with SparseTensorOp no longer raises KeyError.
OpcodeClassNames = {
    OpcodeClass.Simt: 'simt',
    OpcodeClass.TensorOp: 'tensorop',
    OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
    OpcodeClass.SparseTensorOp: 'sptensorop',
}

# C++ opcode-class tag type (in cutlass::arch) emitted for each opcode class.
OpcodeClassTag = {
    OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
    OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
    OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
    OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class OperationKind(enum.Enum):
    """Kinds of operation the generator can emit."""
    Gemm = enum_auto()
    RankK = enum_auto()
    Rank2K = enum_auto()
    Trmm = enum_auto()
    Symm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()

# Lower-case name of each operation kind.
OperationKindNames = {
    OperationKind.Gemm: 'gemm'
    , OperationKind.RankK: 'rank_k'
    , OperationKind.Rank2K: 'rank_2k'
    , OperationKind.Trmm: 'trmm'
    , OperationKind.Symm: 'symm'
    , OperationKind.Conv2d: 'conv2d'
    , OperationKind.Conv3d: 'conv3d'
}
|
||||
|
||||
#
|
||||
class Target(enum.Enum):
    """Emission targets; only ``library`` is defined."""
    library = enum_auto()
|
||||
#
|
||||
# Architecture name keyed by an integer architecture number (presumably the
# CUDA compute capability, e.g. 80 -> 'ampere' - confirm against callers).
ArchitectureNames = {
    50: 'maxwell',
    60: 'pascal',
    61: 'pascal',
    70: 'volta',
    75: 'turing',
    80: 'ampere',
    89: 'ada',
    90: 'hopper'
}
|
||||
|
||||
#
|
||||
# Shared memory capacity in KB, keyed by compute capability (per the
# per-entry comments below).
SharedMemPerCC = {
    70: 96,   # 96KB of SMEM
    72: 96,   # 96KB of SMEM
    75: 64,   # 64KB of SMEM
    80: 163,  # 163KB of SMEM - 1KB reserved for the driver
    86: 99,   # 99KB of SMEM - 1KB reserved for the driver
    87: 163,  # 163KB of SMEM - 1KB reserved for the driver
    89: 99,   # 99KB of SMEM - 1KB reserved for the driver
    90: 227,  # 227KB of SMEM - 1KB reserved for the driver
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in ``template`` using ``values``.

    Substitution is repeated until a full pass over ``values`` changes
    nothing, so a substituted value may itself contain placeholders that are
    expanded on a later pass. A value that (directly or indirectly) contains
    its own placeholder therefore never terminates.

    Args:
        template: text containing ``${key}`` placeholders.
        values: mapping from placeholder name to replacement string. Each key
            is interpolated into a regular expression unescaped.

    Returns:
        The template with every known placeholder replaced. Placeholders
        whose key is not in ``values`` are left untouched.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            # Use a function replacement so the value is inserted literally.
            # Passing the string directly makes re.sub interpret backslash
            # escapes and group references in it, which raises re.error or
            # corrupts values such as Windows paths.
            newtext = re.sub(regex, lambda _match, _value=value: _value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class GemmKind(enum.Enum):
    """Kinds of GEMM operation."""
    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    Universal3x = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()

# Name of each GEMM kind; Gemm, Universal and Universal3x all map to "gemm".
GemmKindNames = {
    GemmKind.Gemm: "gemm",
    GemmKind.Sparse: "spgemm",
    GemmKind.Universal: "gemm",
    GemmKind.Universal3x: "gemm",
    GemmKind.PlanarComplex: "gemm_planar_complex",
    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
    GemmKind.Grouped: "gemm_grouped"
}
|
||||
|
||||
#
|
||||
class RankKKind(enum.Enum):
    """Kinds of rank-k operation; only ``Universal`` is defined."""
    Universal = enum_auto()

# Name of each rank-k kind.
RankKKindNames = {
    RankKKind.Universal: "rank_k"
}
|
||||
|
||||
#
|
||||
class TrmmKind(enum.Enum):
    """Kinds of TRMM operation; only ``Universal`` is defined."""
    Universal = enum_auto()

# Name of each TRMM kind.
TrmmKindNames = {
    TrmmKind.Universal: "trmm"
}
|
||||
|
||||
#
|
||||
class SymmKind(enum.Enum):
    """Enumerates SYMM kinds; only ``Universal`` is defined."""

    Universal = enum_auto()
|
||||
|
||||
#
|
||||
# Name fragment used for each SYMM kind.
SymmKindNames = {SymmKind.Universal: "symm"}
|
||||
|
||||
#
|
||||
class EpilogueFunctor(enum.Enum):
    """Enumerates the epilogue functors that have a C++ tag in this file."""

    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()
|
||||
|
||||
#
|
||||
# Fully-qualified C++ type name for each epilogue functor.
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination:
        'cutlass::epilogue::thread::LinearCombination',
    EpilogueFunctor.LinearCombinationClamp:
        'cutlass::epilogue::thread::LinearCombinationClamp',
}
|
||||
|
||||
#
|
||||
class SwizzlingFunctor(enum.Enum):
    """Enumerates the threadblock swizzling functors that have a C++ tag in this file."""

    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
    StreamK = enum_auto()
|
||||
|
||||
#
|
||||
# Fully-qualified C++ type name for each swizzling functor.
SwizzlingFunctorTag = {
    # GEMM threadblock swizzles
    SwizzlingFunctor.Identity1:
        'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.Identity2:
        'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
    SwizzlingFunctor.Identity4:
        'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.Identity8:
        'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
    SwizzlingFunctor.Horizontal:
        'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
    # Convolution strided-dgrad swizzles
    SwizzlingFunctor.StridedDgradIdentity1:
        'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.StridedDgradIdentity4:
        'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.StridedDgradHorizontal:
        'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
    # Stream-K
    SwizzlingFunctor.StreamK:
        'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}
|
||||
|
||||
#
|
||||
class GroupScheduleMode(enum.Enum):
    """Enumerates the scheduling modes for grouped kernels.

    Both members take a plain ``enum_auto()`` value. A stray trailing comma
    previously turned ``Device``'s value into a one-element tuple.
    """

    Device = enum_auto()
    Host = enum_auto()
|
||||
|
||||
#
|
||||
# C++ enumerator emitted for each group schedule mode.
GroupScheduleModeTag = {
    GroupScheduleMode.Device:
        'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
    GroupScheduleMode.Host:
        'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute',
}

# Short label for each group schedule mode.
ShortGroupScheduleModeNames = {
    GroupScheduleMode.Device: 'Device',
    GroupScheduleMode.Host: 'Host',
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class ConvKind(enum.IntEnum):
    """Enumerates convolution operator kinds, with fixed integer values."""

    Fprop = 0
    Dgrad = 1
    Wgrad = 2
|
||||
|
||||
#
|
||||
# C++ enumerator emitted for each convolution kind.
ConvKindTag = {
    ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
    ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
    ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad',
}

# Lower-case name fragment for each convolution kind.
ConvKindNames = {
    ConvKind.Fprop: 'fprop',
    ConvKind.Dgrad: 'dgrad',
    ConvKind.Wgrad: 'wgrad',
}
|
||||
|
||||
class ConvMode(enum.IntEnum):
    """Enumerates convolution modes, with fixed integer values."""

    CrossCorrelation = 0
    Convolution = 1
|
||||
|
||||
#
|
||||
class IteratorAlgorithm(enum.Enum):
    """Enumerates convolution iterator algorithms, with explicit values."""

    Analytic = 0
    Optimized = 1
    FixedChannels = 2
    FewChannels = 3
    FixedStrideDilation = 4
|
||||
|
||||
#
|
||||
# C++ enumerator emitted for each iterator algorithm.
IteratorAlgorithmTag = {
    IteratorAlgorithm.Analytic:
        'cutlass::conv::IteratorAlgorithm::kAnalytic',
    IteratorAlgorithm.Optimized:
        'cutlass::conv::IteratorAlgorithm::kOptimized',
    IteratorAlgorithm.FixedChannels:
        'cutlass::conv::IteratorAlgorithm::kFixedChannels',
    IteratorAlgorithm.FewChannels:
        'cutlass::conv::IteratorAlgorithm::kFewChannels',
    IteratorAlgorithm.FixedStrideDilation:
        'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation',
}

# Lower-case name fragment for each iterator algorithm.
IteratorAlgorithmNames = {
    IteratorAlgorithm.Analytic: 'analytic',
    IteratorAlgorithm.Optimized: 'optimized',
    IteratorAlgorithm.FixedChannels: 'fixed_channels',
    IteratorAlgorithm.FewChannels: 'few_channels',
    IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation',
}
|
||||
|
||||
#
|
||||
class StrideSupport(enum.Enum):
    """Enumerates convolution stride-support modes, with explicit values."""

    Strided = 0
    Unity = 1
    Fixed = 2
|
||||
|
||||
#
|
||||
# C++ enumerator emitted for each stride-support mode.
StrideSupportTag = {
    StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
    StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
    StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed',
}

# Name fragment for each stride-support mode; Strided maps to the empty string.
StrideSupportNames = {
    StrideSupport.Strided: '',
    StrideSupport.Unity: 'unity_stride',
    StrideSupport.Fixed: 'fixed_stride',
}
|
||||
|
||||
#
|
||||
class GroupMode(enum.Enum):
    """Enumerates the group modes of a convolution."""

    # dense conv (G=1)
    NoneGroup = enum_auto()
    # grouped convolution (single group per CTA)
    SingleGroup = enum_auto()
    # grouped convolution (multiple groups per CTA)
    MultipleGroup = enum_auto()
    # depthwise convolution (C=K=G)
    Depthwise = enum_auto()
|
||||
|
||||
#
|
||||
# C++ enumerator emitted for each group mode.
GroupModeTag = {
    GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
    GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
    GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
    GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

# Name fragment for each group mode; NoneGroup maps to the empty string.
GroupModeNames = {
    GroupMode.NoneGroup: '',
    GroupMode.SingleGroup: 'single_group',
    GroupMode.MultipleGroup: 'multiple_group',
    GroupMode.Depthwise: 'depthwise',
}
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class MathInstruction:
    """Plain record describing one math instruction.

    The constructor stores every argument unchanged on the attribute of the
    same name; ``math_operation`` defaults to ``MathOperation.multiply_add``.
    """

    def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add):
        self.instruction_shape = instruction_shape
        # Operand and accumulator element types.
        self.element_a, self.element_b = element_a, element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation
|
||||
|
||||
#
|
||||
class TileDescription:
    """Describes a threadblock tile and the compute-capability range it targets.

    Args:
        threadblock_shape: Indexable with at least three entries; entries
            0, 1 and 2 are used by ``procedural_name``. Also stored as
            ``tile_shape`` (same object).
        stages: Stage count; appears in the procedural name.
        warp_count: Stored unchanged.
        math_instruction: Stored unchanged.
        min_compute: Stored as ``minimum_compute_capability``.
        max_compute: Stored as ``maximum_compute_capability``.
        cluster_shape: Indexable with at least three entries. When omitted
            (``None``) a fresh ``[1, 1, 1]`` list is created per instance, so
            instances never share one mutable default list.
    """

    def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = None):
        self.threadblock_shape = threadblock_shape
        self.tile_shape = threadblock_shape
        self.stages = stages
        self.warp_count = warp_count
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute
        # A new list per instance: mutating one description's cluster shape
        # must not leak into every other description built with the default.
        self.cluster_shape = [1, 1, 1] if cluster_shape is None else cluster_shape

    def procedural_name(self):
        """Return the tile's name fragment.

        For ``minimum_compute_capability >= 90`` the format is
        ``{M}x{N}x{K}_{cm}x{cn}x{ck}_{stages}``; otherwise it is
        ``{M}x{N}_{K}x{stages}``.
        """
        if self.minimum_compute_capability >= 90:
            return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
                tbm = self.threadblock_shape[0],
                tbn = self.threadblock_shape[1],
                tbk = self.threadblock_shape[2],
                cm = self.cluster_shape[0],
                cn = self.cluster_shape[1],
                ck = self.cluster_shape[2],
                s = self.stages)
        return "%dx%d_%dx%d" % (
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            self.stages)
|
||||
|
||||
#
|
||||
class Direct2dConvFixedStrideDilationTileDescription:
    """Tile description for direct 2D convolution with optional fixed stride/dilation.

    ``threadblock_shape`` is derived as
    ``[out[0]*out[1]*out[2], out[3], filter[0]*filter[1]]`` where ``out`` is
    ``threadblock_output_shape`` and ``filter`` is ``filter_shape``.

    NOTE(review): a class with this exact name is defined again immediately
    after this one in the file; the later definition rebinds the name.
    """

    def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
        out = threadblock_output_shape
        self.threadblock_shape = [out[0]*out[1]*out[2], out[3], filter_shape[0]*filter_shape[1]]
        self.threadblock_output_shape = threadblock_output_shape
        self.filter_shape = filter_shape
        self.stages = stages
        self.warp_count = warp_count
        self.stride = stride
        self.dilation = dilation
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return the name fragment for this tile.

        The stride/dilation suffix is appended only when neither ``stride``
        nor ``dilation`` equals ``[-1, -1]``.
        """
        tb = self.threadblock_shape
        out = self.threadblock_output_shape
        flt = self.filter_shape
        base = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
            tb[0], tb[1], tb[2],
            out[0], out[1], out[2], out[3],
            self.stages,
            flt[0], flt[1])

        # [-1, -1] on either field means no fixed stride/dilation suffix.
        if self.stride == [-1, -1] or self.dilation == [-1, -1]:
            return base

        return base + "_stride%dx%d_dilation%dx%d" % (
            self.stride[0], self.stride[1],
            self.dilation[0], self.dilation[1])
|
||||
|
||||
#
|
||||
class Direct2dConvFixedStrideDilationTileDescription:
    """Tile description for direct 2D convolution with optional fixed stride/dilation.

    NOTE(review): this repeats the class of the same name defined immediately
    before it in the file; being later, this definition is the one the name
    ends up bound to.
    """

    def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
        # Derived shape: [out0*out1*out2, out3, filter0*filter1].
        rows = threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2]
        cols = threadblock_output_shape[3]
        taps = filter_shape[0]*filter_shape[1]
        self.threadblock_shape = [rows, cols, taps]

        self.threadblock_output_shape = threadblock_output_shape
        self.filter_shape = filter_shape
        self.stages = stages
        self.warp_count = warp_count
        self.stride = stride
        self.dilation = dilation
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return the name fragment for this tile.

        A ``_stride..._dilation...`` suffix is added only when both ``stride``
        and ``dilation`` differ from ``[-1, -1]``.
        """
        pieces = [
            "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
                self.threadblock_shape[0],
                self.threadblock_shape[1],
                self.threadblock_shape[2],
                self.threadblock_output_shape[0],
                self.threadblock_output_shape[1],
                self.threadblock_output_shape[2],
                self.threadblock_output_shape[3],
                self.stages,
                self.filter_shape[0],
                self.filter_shape[1])
        ]

        # Fixed stride and dilation
        has_fixed_stride_dilation = self.stride != [-1, -1] and self.dilation != [-1, -1]
        if has_fixed_stride_dilation:
            pieces.append("_stride%dx%d_dilation%dx%d" % (
                self.stride[0],
                self.stride[1],
                self.dilation[0],
                self.dilation[1]))

        return "".join(pieces)
|
||||
|
||||
#
|
||||
class TensorDescription:
    """Plain record describing a tensor operand.

    Stores ``element``, ``layout``, ``alignment`` (default 1) and
    ``complex_transform`` (default ``ComplexTransform.none``) unchanged.
    """

    def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
        self.element, self.layout = element, layout
        self.alignment = alignment
        self.complex_transform = complex_transform
|
||||
|
||||
#
|
||||
class SymmetricTensorDescription:
    """Plain record describing a symmetric tensor operand.

    Stores every constructor argument unchanged. Defaults: ``alignment`` 1,
    ``complex_transform`` ``ComplexTransform.none``, ``side_mode``
    ``SideMode.Left``.
    """

    def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
        self.element, self.layout = element, layout
        self.fill_mode = fill_mode
        self.alignment = alignment
        self.complex_transform = complex_transform
        self.side_mode = side_mode
|
||||
|
||||
#
|
||||
class TriangularTensorDescription:
    """Plain record describing a triangular tensor operand.

    Stores every constructor argument unchanged. Defaults: ``alignment`` 1,
    ``complex_transform`` ``ComplexTransform.none``.
    """

    def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
        self.element, self.layout = element, layout
        self.side_mode = side_mode
        self.fill_mode = fill_mode
        self.diag_type = diag_type
        self.alignment = alignment
        self.complex_transform = complex_transform
|
||||
|
||||
#
|
||||
def CalculateSmemUsage(operation):
    """Estimate the shared memory used by ``operation``, in whole KB.

    The per-stage byte count is computed from the threadblock shape and the
    ``DataTypeSize`` (bits) of the operand element types, multiplied by the
    stage count, then shifted right by 10 (floor division by 1024).

    Sparse GEMMs count half of K for A, full K for B, plus metadata. Every
    other operation counts two tiles (M x K and N x K), both sized with A's
    element type.
    """
    tile = operation.tile_description
    cta_m, cta_n, cta_k = (tile.threadblock_shape[i] for i in range(3))
    stages = tile.stages

    sparse_gemm = (operation.operation_kind == OperationKind.Gemm
                   and operation.gemm_kind == GemmKind.Sparse)
    a_bits = DataTypeSize[operation.A.element]

    if sparse_gemm:
        # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
        elements_per_8b_md = {32: 2, 4: 8}.get(a_bits, 4)

        half_k = cta_k // 2
        a_bytes = a_bits * cta_m * half_k // 8
        b_bytes = DataTypeSize[operation.B.element] * cta_n * cta_k // 8
        metadata_bytes = cta_m * half_k // elements_per_8b_md
        smem_per_stage = a_bytes + b_bytes + metadata_bytes
    else:
        # Few BLAS3 operations only have A tensor
        smem_per_stage = (a_bits * cta_m * cta_k // 8
                          + a_bits * cta_n * cta_k // 8)

    return (smem_per_stage * stages) >> 10
|
||||
|
||||
|
||||
class GemmUniversalMode(enum.IntEnum):
    """
    Integer-valued modes of a universal GEMM
    """

    Gemm = 0
    GemmSplitKParallel = 1
    Batched = 2
    Array = 3
|
||||
|
||||
|
||||
class SplitKMode(enum.IntEnum):
    """
    Integer-valued split-K modes
    """

    NoneSplitK = 0
    Serial = 1
    Parallel = 2
|
||||
683
python/cutlass_library/manifest.py
Normal file
683
python/cutlass_library/manifest.py
Normal file
@@ -0,0 +1,683 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for filtering CUTLASS library kernels and emitting library initialization
|
||||
and building code
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
from cutlass_library.library import *
|
||||
from cutlass_library.gemm_operation import *
|
||||
from cutlass_library.rank_k_operation import *
|
||||
from cutlass_library.rank_2k_operation import *
|
||||
from cutlass_library.trmm_operation import *
|
||||
from cutlass_library.symm_operation import *
|
||||
from cutlass_library.conv2d_operation import *
|
||||
from cutlass_library.conv3d_operation import *
|
||||
import logging
|
||||
|
||||
###################################################################################################
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmitOperationKindAll:
    """Context manager that writes ``all_<kind>_operations.cu`` for one operation kind.

    On entry the file is opened under ``<generated_path>/<kind name>/`` and
    the header is written. ``emit`` records configuration names and writes
    one prototype per configuration. On exit the entry-point function is
    written, with one call per recorded configuration, and the file is closed.
    """

    def __init__(self, generated_path, kind, args):
        self.generated_path = generated_path
        self.kind = kind
        self.args = args

        self.header_template ="""
/*
Generated by manifest.py - Do not edit.
*/

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.entry_template = """

//
// Entry point to construct operations
//
void initialize_all_${operation_name}_operations(Manifest &manifest) {
"""
        self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
        self.configuration_template =" initialize_${configuration_name}(manifest);\n"

        self.epilogue_template ="""}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

"""

    #
    def __enter__(self):
        """Create the output directory, open the top-level file and write its header."""
        self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind])
        os.makedirs(self.operation_path, exist_ok=True)

        top_level_name = f"all_{OperationKindNames[self.kind]}_operations.cu"
        self.top_level_path = os.path.join(self.operation_path, top_level_name)

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.header_template)

        self.source_files = [self.top_level_path,]
        self.configurations = []

        return self

    #
    def emit(self, operations):
        """Record every configuration name in ``operations`` and write its prototype.

        ``operations`` maps a key to a mapping of configuration name to
        operations; outer entries are visited in sorted order.
        """
        for _, per_cc_configurations in sorted(operations.items()):
            for configuration_name in per_cc_configurations:
                self.configurations.append(configuration_name)
                prototype = SubstituteTemplate(
                    self.configuration_prototype_template,
                    {'configuration_name': configuration_name})
                self.top_level_file.write(prototype)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        """Write the entry point, one call per configuration, the epilogue, then close the file."""
        out = self.top_level_file
        out.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]}))

        for configuration_name in self.configurations:
            out.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name}))

        out.write(self.epilogue_template)
        out.close()
|
||||
|
||||
|
||||
class EmitOperationKindLibrary:
    """Context manager that emits the sources for one (operation kind, min_cc) pair.

    A top-level ``all_sm<cc>_<kind>_operations.cu`` is written under
    ``<generated_path>/<kind name>/<min_cc>/``. Each distinct
    ``extended_name()`` of the emitted operations gets its own subdirectory
    and its own top-level file; per-configuration sources are produced by the
    emitter class registered for the kind in ``self.emitters``.
    """

    def __init__(self, generated_path, min_cc, kind, args):
        self.generated_path = generated_path
        self.min_cc = min_cc
        self.kind = kind
        self.args = args

        # Configuration emitter class to use for each operation kind.
        self.emitters = {
            OperationKind.Gemm: EmitGemmConfigurationLibrary,
            OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
            OperationKind.Conv3d: EmitConv3dConfigurationLibrary,
            OperationKind.RankK: EmitRankKConfigurationLibrary,
            OperationKind.Rank2K: EmitRank2KConfigurationLibrary,
            OperationKind.Trmm: EmitTrmmConfigurationLibrary,
            OperationKind.Symm: EmitSymmConfigurationLibrary
        }

        self.header_template ="""
/*
Generated by manifest.py - Do not edit.
*/

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        self.entry_template = """

//
// Entry point to construct operations
//
void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) {
"""
        self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
        self.configuration_template = " initialize_${configuration_name}(manifest);\n"
        self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n"

        self.epilogue_template ="""}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

"""

    #
    def __enter__(self):
        """Create the output directory, open the top-level file and reset bookkeeping."""
        self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc))
        # No exist_ok here: an already existing directory raises.
        os.makedirs(self.operation_path)

        top_level_name = f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu"
        self.top_level_path = os.path.join(self.operation_path, top_level_name)

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.header_template)

        # extended_name -> list of source paths written for that subclass
        self.source_files = {}

        # Each {operation_kind x cc} combination is further decomposed by the
        # instruction types used. This dictionary tracks the file handles for
        # the top-level files of each subclass.
        self.subclass_files = {}

        # extended_name -> configuration names emitted into that subclass
        self.subclass_configurations = {}

        return self

    #
    def emit(self, configuration_name, operations):
        """Emit one configuration's operations into the matching subclass directory.

        ``operations`` must be non-empty. The subclass is chosen from the
        ``extended_name()`` of the first operation.
        """
        assert len(operations) > 0

        # The extended name for all operations of a given configuration_name is
        # guaranteed to be the same because extended_name() is used in defining
        # configuration_name, so the first operation's value is used.
        extended_name = operations[0].extended_name()

        # First time this subclass is seen: create its directory and top-level file.
        if extended_name not in self.subclass_files:
            subclass_path = os.path.join(self.operation_path, extended_name)
            os.mkdir(subclass_path)

            self.subclass_configurations[extended_name] = []

            subclass_top_level_path = os.path.join(
                subclass_path,
                f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu")
            self.subclass_files[extended_name] = open(subclass_top_level_path, "w")
            self.subclass_files[extended_name].write(self.header_template)

            self.source_files[extended_name] = [subclass_top_level_path]

        subclass_file = self.subclass_files[extended_name]
        subclass_dir = os.path.dirname(subclass_file.name)
        emitter_class = self.emitters[self.kind]
        with emitter_class(subclass_dir, configuration_name) as configuration_emitter:
            for operation in operations:
                configuration_emitter.emit(operation)

            self.source_files[extended_name].append(configuration_emitter.configuration_path)

        self.subclass_configurations[extended_name].append(configuration_name)
        subclass_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} ))

    #
    def __exit__(self, exception_type, exception_value, traceback):
        """Finish every subclass file and the top-level file, then close them."""
        operation_name = OperationKindNames[self.kind]

        # The top-level entry point uses an empty subclass name.
        self.top_level_file.write(
            SubstituteTemplate(self.entry_template, {
                'min_cc': str(self.min_cc),
                'subclass_name': '',
                'operation_name': operation_name
            }))

        # Finish and close all subclass files
        for subclass_name, subclass_file in sorted(self.subclass_files.items()):
            subclass_cfg = {
                'min_cc': str(self.min_cc),
                'subclass_name': subclass_name,
                'operation_name': operation_name
            }
            subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg))

            for configuration in self.subclass_configurations[subclass_name]:
                subclass_file.write(
                    SubstituteTemplate(self.configuration_template, {
                        'configuration_name': configuration
                    }))

            subclass_file.write(self.epilogue_template)
            subclass_file.close()

            # Write the call to initialize_all for this subclass to the top-level file
            self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg))

        self.top_level_file.write(self.epilogue_template)
        self.top_level_file.close()
|
||||
|
||||
class EmitInterfaceLibrary:
    """Context manager that writes ``initialize_all.cpp`` into ``generated_path``.

    ``emit`` collects one prototype and one call per operation name. On exit
    the prototypes, the ``initialize_all`` function (reserving
    ``operation_count`` entries, then making each call) and the closing
    namespaces are written.
    """

    def __init__(self, generated_path, operation_count, args):
        self.generated_path = generated_path
        self.args = args

        self.prototypes = []
        self.fn_calls = []
        # Kept as a string because it is used as a template substitution value.
        self.operation_count = str(operation_count)

        self.top_level_hdr_template = '''
/*
Generated by manifest.py - Do not edit.
*/
'''
        self.top_level_prologue = '''

#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

namespace cutlass {
\tnamespace library {

${prototypes}
'''

        self.top_level_initialize_kind = '''
\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) {
${fn_calls}
\t\t}
'''

        self.top_level_initialize = '''
\t\tvoid initialize_all(Manifest &manifest) {
\t\t\tmanifest.reserve(${operation_count});\n
${fn_calls}
\t\t}
'''

        self.top_level_suffix = '''
\t} // namespace library
} // namespace cutlass

'''

    #
    def __enter__(self):
        """Open ``initialize_all.cpp`` and write the generated-file banner."""
        self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp')

        self.top_level_file = open(self.top_level_path, "w")
        self.top_level_file.write(self.top_level_hdr_template)

        self.source_files = [self.top_level_path,]

        return self

    #
    def emit(self, operation_name):
        """Queue the prototype and the call for ``initialize_all_<operation_name>_operations``."""
        substitutions = {'operation_kind': operation_name}

        self.prototypes.append(SubstituteTemplate(
            "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);",
            substitutions))

        self.fn_calls.append(SubstituteTemplate(
            "\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
            substitutions))

    #
    def __exit__(self, exception_type, exception_value, traceback):
        """Write prototypes, the ``initialize_all`` body and the suffix, then close the file."""
        out = self.top_level_file
        out.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)}))

        # Write out initialize_all method
        out.write(SubstituteTemplate(
            self.top_level_initialize,
            {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)}))

        out.write(self.top_level_suffix)
        out.close()
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
class Options:
    """Empty attribute container; a new instance has no attributes set."""

    def __init__(self):
        """Do nothing."""
|
||||
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class Manifest:
|
||||
|
||||
#
|
||||
def __init__(self, args = None):
|
||||
self.operations = {}
|
||||
self.args = args
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
|
||||
self.kernel_filter = ''
|
||||
self.kernel_filter_list = []
|
||||
self.kernel_names = []
|
||||
self.operations_enabled = []
|
||||
self.selected_kernels = []
|
||||
self.ignore_kernel_names = []
|
||||
self.compute_capabilities = [50,]
|
||||
self.curr_build_dir = '.'
|
||||
self.filter_by_cc = True
|
||||
|
||||
if self.args:
|
||||
self.kernel_filter = self.args.kernels
|
||||
self.curr_build_dir = args.curr_build_dir
|
||||
|
||||
architectures = args.architectures.split(';') if len(args.architectures) else ['50',]
|
||||
architectures = [x if x != '90a' else '90' for x in architectures]
|
||||
|
||||
self.compute_capabilities = [int(x) for x in architectures]
|
||||
|
||||
if args.filter_by_cc in ['false', 'False', '0']:
|
||||
self.filter_by_cc = False
|
||||
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
|
||||
|
||||
def get_kernel_filters (self, kernelListFile):
    """Read ``kernelListFile`` and return its lines as compiled regular expressions.

    Lines starting with ``#`` are skipped, trailing whitespace is stripped,
    and lines that are empty after stripping are dropped. Returns an empty
    list when ``kernelListFile`` is not an existing file.
    """
    if not os.path.isfile(kernelListFile):
        return []

    with open(kernelListFile, 'r') as fileReader:
        stripped_lines = [
            raw_line.rstrip()
            for raw_line in fileReader
            if not raw_line.startswith("#")
        ]

    return [re.compile(expression) for expression in stripped_lines if expression]
|
||||
|
||||
#
|
||||
def filter_out_kernels(self, kernel_name, kernel_filter_list):
    """Return True if any compiled pattern in ``kernel_filter_list`` is found in ``kernel_name``."""
    return any(
        pattern.search(kernel_name) is not None
        for pattern in kernel_filter_list
    )
|
||||
|
||||
|
||||
#
|
||||
def _filter_string_matches(self, filter_string, haystack):
|
||||
''' Returns true if all substrings appear in the haystack in order'''
|
||||
substrings = filter_string.split('*')
|
||||
for sub in substrings:
|
||||
idx = haystack.find(sub)
|
||||
if idx < 0:
|
||||
return False
|
||||
haystack = haystack[idx + len(sub):]
|
||||
return True
|
||||
|
||||
#
|
||||
def filter(self, operation):
|
||||
''' Filtering operations based on various criteria'''
|
||||
|
||||
# filter based on compute capability
|
||||
enabled = not (self.filter_by_cc)
|
||||
|
||||
for cc in self.compute_capabilities:
|
||||
if cc >= operation.tile_description.minimum_compute_capability and \
|
||||
cc <= operation.tile_description.maximum_compute_capability and \
|
||||
(cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):
|
||||
|
||||
enabled = True
|
||||
break
|
||||
|
||||
if not enabled:
|
||||
return False
|
||||
|
||||
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
|
||||
return False
|
||||
|
||||
# eliminate duplicates
|
||||
if operation.procedural_name() in self.operations_by_name.keys():
|
||||
return False
|
||||
|
||||
# Filter based on list of valid substrings
|
||||
if len(self.kernel_names):
|
||||
name = operation.procedural_name()
|
||||
enabled = False
|
||||
|
||||
# compare against the include list
|
||||
for name_substr in self.kernel_names:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
_LOGGER.debug("Kernel {kernel} included due to filter string '{filt}'.".format(
|
||||
kernel = operation.procedural_name(),
|
||||
filt = name_substr))
|
||||
enabled = True
|
||||
break
|
||||
|
||||
# compare against the exclude list
|
||||
for name_substr in self.ignore_kernel_names:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
_LOGGER.debug("Kernel {kernel} ignored due to filter string '{filt}'.".format(
|
||||
kernel = operation.procedural_name(),
|
||||
filt = name_substr))
|
||||
enabled = False
|
||||
break
|
||||
|
||||
if len(self.kernel_filter_list) > 0:
|
||||
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
|
||||
_LOGGER.debug("Kernel {kernel} matched via kernel filter file.".format(kernel = operation.procedural_name()))
|
||||
enabled = True
|
||||
else:
|
||||
_LOGGER.debug("Kernel {kernel} culled due to no match in kernel filter file.".format(kernel = operation.procedural_name()))
|
||||
enabled = False
|
||||
|
||||
|
||||
# TODO: filter based on compute data type
|
||||
return enabled
|
||||
#
|
||||
|
||||
#
|
||||
def append(self, operation):
    '''
    Inserts the operation if it passes self.filter(); otherwise logs that it was culled.

    Accepted operations are recorded by procedural name and grouped as

      operation_kind -> arch -> configuration_name -> [operations]
    '''
    if not self.filter(operation):
        _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
        return

    kernel_name = operation.procedural_name()
    self.selected_kernels.append(kernel_name)
    self.operations_by_name[kernel_name] = operation

    # Buckets are created on first use: by kind, then by the operation's
    # arch, then by configuration name.
    per_kind = self.operations.setdefault(operation.operation_kind, {})
    per_arch = per_kind.setdefault(operation.arch, {})
    per_arch.setdefault(operation.configuration_name(), []).append(operation)

    self.operation_count += 1
|
||||
#
|
||||
|
||||
def emit_manifest_cmake(self, manifest_path, top_level_path, source_files):
    '''
    Writes the CMake manifest listing every generated source file.

    manifest_path   -- file to create (overwritten if present)
    top_level_path  -- path of the top-level generated source
    source_files    -- mapping kind -> min_cc -> subclass -> [source paths]

    Backslashes in paths are rewritten to forward slashes before being written.
    '''
    def forward_slashes(path):
        return path.replace('\\', '/')

    with open(manifest_path, "w") as manifest_file:
        write = manifest_file.write

        # Sources compiled into cutlass_library_objs: the top-level file
        # plus one all_<kind>_operations.cu per operation kind.
        header_text = SubstituteTemplate(
            "cutlass_target_sources(cutlass_library_objs PRIVATE\n", { })
        write(header_text + '\n\n')
        write(" %s\n" % str(forward_slashes(top_level_path)))

        generated_path = os.path.join(self.curr_build_dir, 'generated')
        for kind in self.operations:
            kind_str = OperationKindNames[kind]
            all_kind_file = forward_slashes(
                os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu"))
            write(f" {all_kind_file}\n")
        write(')\n\n')

        # One cutlass_add_cutlass_library() call per (kind, min_cc, subclass).
        for kind, per_cc in self.operations.items():
            for min_cc in sorted(per_cc):
                per_subclass = source_files[kind][min_cc]
                for subclass in sorted(per_subclass):
                    target_text = SubstituteTemplate(
                        "cutlass_add_cutlass_library(\n"
                        "SUFFIX ${kind}_sm${min_cc}_${subclass}\n",
                        {
                            'min_cc': str(min_cc),
                            'kind': OperationKindNames[kind],
                            'subclass': subclass
                        })
                    write(target_text + '\n\n')

                    for source_file in per_subclass[subclass]:
                        write(" %s\n" % str(forward_slashes(source_file)))

                    write(")\n")

        if self.disable_full_archs_compilation:
            self.emit_disable_full_archs_compilation(manifest_file, source_files)
|
||||
|
||||
def emit_disable_full_archs_compilation(self, manifest_file, source_files):
    '''
    Writes one cutlass_apply_cuda_gencode_flags() line per generated CUDA source.

    Each file name is matched by substring against known instruction-shape
    patterns to choose the SM architectures it may be compiled for; that set
    is then intersected with self.compute_capabilities.

    manifest_file -- open, writable text file
    source_files  -- mapping kind -> min_cc -> subclass -> [source paths],
                     the same mapping emit_manifest_cmake() walks

    Files ending in ".cpp" are skipped. Raises RuntimeError when a file
    matches no pattern, or when none of its architectures was requested.
    '''
    requested_archs = set(self.compute_capabilities)

    # NOTE(review): there is no Hopper-specific rule; a file matching none of
    # the patterns below ends up in the RuntimeError branch.

    def for_ampere(name):
        return "16816" in name or \
               "16832" in name or \
               "16864" in name or \
               ("1688" in name and "tf32" in name)

    def for_turing(name):
        return ("1688" in name and "tf32" not in name) or \
               "8816" in name

    def for_volta(name):
        return "884" in name

    def is_cpp(name):
        return name.endswith(".cpp")

    def get_src_archs_str_given_requested_cuda_archs(archs, source_file):
        intersected_archs = archs & requested_archs
        if not intersected_archs:
            raise RuntimeError(
                "\n"
                "Empty archs set for file {} after taking\n"
                "the intersection of {} (global requested archs) and\n"
                "{} (per file requested archs)\n".format(source_file, requested_archs, archs))
        # Sorted so the emitted line does not depend on set iteration order.
        return " ".join(map(str, sorted(intersected_archs)))

    # Kinds are visited in insertion order; min_cc and subclass are sorted,
    # mirroring the traversal in emit_manifest_cmake().
    for kind in source_files:
        for min_cc in sorted(source_files[kind]):
            for subclass in sorted(source_files[kind][min_cc]):
                for source_file in source_files[kind][min_cc][subclass]:
                    if is_cpp(source_file):
                        continue  # skip because source is cpp
                    elif for_ampere(source_file):
                        archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file)
                    elif for_turing(source_file):
                        archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file)
                    elif for_volta(source_file):
                        archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
                    else:
                        raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))

                    manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
|
||||
|
||||
#
|
||||
def emit(self, target = GeneratorTarget.Library):
    '''
    Regenerates the 'generated' directory under self.curr_build_dir and writes manifest.cmake.

    Any existing 'generated' directory is deleted first. The interface
    emitter, the per-(kind, min_cc) operation emitters and the per-kind
    emitters are then run as context managers, and the source paths the
    operation emitters report are collected as
    kind -> min_cc -> subclass -> [paths] and handed to emit_manifest_cmake().

    target selects the emitter classes; only GeneratorTarget.Library has
    entries in the tables below.
    '''

    # NOTE(review): block nesting below was reconstructed from data flow
    # (which names each statement uses); confirm against the upstream file.

    operation_emitters = {
        GeneratorTarget.Library: EmitOperationKindLibrary
    }

    # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d)
    kind_emitters = {
        GeneratorTarget.Library: EmitOperationKindAll
    }

    interface_emitters = {
        GeneratorTarget.Library: EmitInterfaceLibrary
    }

    generated_path = os.path.join(self.curr_build_dir, 'generated')

    # create generated/
    if os.path.exists(generated_path):
        shutil.rmtree(generated_path)

    os.mkdir(generated_path)

    with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
        top_level_path = iface_emitter.top_level_path
        for operation_kind in self.operations.keys():
            iface_emitter.emit(OperationKindNames[operation_kind])

    # Pre-create the kind -> min_cc levels so the loop below only has to add
    # subclass entries.
    source_files = {}
    for kind in self.operations.keys():
        source_files[kind] = {}
        for min_cc in self.operations[kind].keys():
            source_files[kind][min_cc] = {}

    for operation_kind, ops in self.operations.items():
        for min_cc, configurations in sorted(ops.items()):
            with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter:
                for configuration_name, operations in configurations.items():
                    _LOGGER.info("Emitting {config} with {num_ops} operations.".format(
                        config = configuration_name, num_ops = len(operations)))
                    operation_kind_emitter.emit(configuration_name, operations)

                # Collect the paths this emitter reported, grouped by subclass.
                for subclass, files in operation_kind_emitter.source_files.items():
                    if subclass not in source_files[operation_kind][min_cc]:
                        source_files[operation_kind][min_cc][subclass] = []
                    source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass])

        # Emit top level all_{gemm, conv2d, ...}_operations.cu files
        with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter:
            operation_kind_emitter.emit(ops)

    # write the manifest.cmake file containing paths from all targets
    manifest_path = os.path.join(generated_path, "manifest.cmake")

    self.emit_manifest_cmake(manifest_path, top_level_path, source_files)
|
||||
|
||||
###################################################################################################
|
||||
428
python/cutlass_library/rank_2k_operation.py
Normal file
428
python/cutlass_library/rank_2k_operation.py
Normal file
@@ -0,0 +1,428 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting Rank2K kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Data structure modeling a Rank 2K update operation
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class Rank2KOperation:
    '''
    Describes one Rank2K (syr2k / her2k) kernel and derives the names used to emit it.

    Only the A operand is passed in; B is stored as the same description
    object as A.
    '''

    def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue,
                 epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8,
                 blas_mode = BlasMode.symmetric):

        self.blas_mode = blas_mode
        self.operation_kind = OperationKind.Rank2K
        self.arch = arch
        self.tile_description = tile_description
        self.rank_k_kind = rank_k_kind
        # tensor A and B have same data type and layout
        self.A = A
        self.B = A
        self.C = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

    def is_complex(self):
        ''' Returns True when the math operation is one of the complex multiply-add variants. '''
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    def is_planar_complex(self):
        ''' Always False for this operation. '''
        return False

    def accumulator_type(self):
        ''' Returns the accumulator element type, mapped through get_complex_from_real() for complex operations. '''
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum

    def short_math_name(self):
        ''' Returns the short accumulator type name, prefixed with "g" for the Gaussian complex multiply-add. '''
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % ShortDataTypeNames[self.accumulator_type()]
        return ShortDataTypeNames[self.accumulator_type()]

    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        math_inst = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and'
        }

        if math_inst.opcode_class in (OpcodeClass.TensorOp, OpcodeClass.WmmaTensorOp):
            # Tensor-op names carry the instruction shape and, for the
            # popcount operations, an 'xor' / 'and' suffix.
            inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)
            inst_shape += math_operations_map.get(math_inst.math_operation, '')

            # NOTE(review): this check is placed inside the tensor-op branch;
            # the indentation was lost in the reviewed copy, so confirm.
            if math_inst.element_a != self.A.element and \
               math_inst.element_a != math_inst.element_accumulator:
                intermediate_type = DataTypeNames[math_inst.element_a]

        operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k'

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)

    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            extended_name = "${core_name}"
        elif self.A.element == accumulator:
            extended_name = "${core_name}"
        elif self.C.element != accumulator:
            extended_name = "${element_c}_${core_name}_${element_a}"
        else:
            extended_name = "${core_name}_${element_a}"

        return SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

    def layout_name(self):
        ''' Returns the short layout name of A; complex operations also encode A's complex transform. '''
        if self.is_complex() or self.is_planar_complex():
            return "%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
            )
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    def fill_mode_name(self):
        ''' Returns the short name of C's fill mode. '''
        return "%s" % (ShortFillModeNames[self.C.fill_mode])

    def procedural_name(self):
        '''
        Returns the full kernel name:
        cutlass_<opcode class>_<extended name>_<threadblock>_<layout>_<fill mode>_align<A alignment>.
        '''
        threadblock = self.tile_description.procedural_name()

        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'fill_mode': self.fill_mode_name(),
                # The name encodes A's alignment only.
                'alignment': "%d" % self.A.alignment,
            }
        )

    def configuration_name(self):
        ''' The configuration name is identical to the procedural name. '''
        return self.procedural_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitRank2KUniversalInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self):
        # Text shared by the real and complex instantiations; the two
        # templates differ only in their trailing template arguments.
        shared_prefix = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::Rank2K<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c}, ${fill_mode},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
${split_k_serial},
"""
        self.rank_k_template = shared_prefix + "${math_operation}\n>;\n"
        self.rank_k_complex_template = shared_prefix + (
            "${math_operation},\n"
            "${transform_a},\n"
            "${transform_b},\n"
            "${blas_mode}\n"
            ">;\n"
        )

    def emit(self, operation):
        ''' Returns the template text for operation; complex operations use the complex template. '''
        tile = operation.tile_description
        math_inst = tile.math_instruction
        tb_shape = tile.threadblock_shape
        inst_shape = math_inst.instruction_shape

        # Warp tile = threadblock tile divided by the warp count per dimension.
        warps = tile.warp_count
        warp_shape = [tb_shape[dim] // warps[dim] for dim in range(3)]

        # Epilogue vector length, in elements, capped at 128 bits per access.
        c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_bits, 128) / c_bits)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'fill_mode': FillModeTag[operation.C.fill_mode],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tb_shape[0]),
            'threadblock_shape_n': str(tb_shape[1]),
            'threadblock_shape_k': str(tb_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(inst_shape[0]),
            'instruction_shape_n': str(inst_shape[1]),
            'instruction_shape_k': str(inst_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[math_inst.math_operation],
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'blas_mode': BlasModeTag[operation.blas_mode]
        }

        if operation.is_complex():
            rank_k_template = self.rank_k_complex_template
        else:
            rank_k_template = self.rank_k_template

        return SubstituteTemplate(rank_k_template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitter functions for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitRank2KConfigurationLibrary:
    '''
    Context manager that writes one <configuration_name>.cu file.

    emit() queues an instance definition and a manifest wrapper per
    operation; the queued text is written, followed by the closing
    boilerplate, when the context exits.
    '''

    # Opening preprocessor guard wrapped around WmmaTensorOp instances; emit()
    # substitutes ${sm_number} with the operation's arch and closes the guard
    # with "#endif". emit() reads this attribute, so it must exist.
    wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        self.instance_emitter = {
            RankKKind.Universal: EmitRank2KUniversalInstance,
        }

        self.rank_k_kind_wrappers = {
            RankKKind.Universal: 'Rank2KOperation',
        }

        self.instance_template = {
            RankKKind.Universal: """
${compile_guard_start}
manifest.append(new ${rank_k_kind}<
Operation_${operation_name}
>("${operation_name}"));
${compile_guard_end}
"""
        }

        self.header_template = """
/*
Generated by rank_2k_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "rank_2k_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        ''' Opens the configuration file, writes the header and resets the queues. '''
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        ''' Queues the instance definition and the manifest wrapper for operation. '''
        emitter = self.instance_emitter[operation.rank_k_kind]()

        self.operations.append(operation)

        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA tensor-op instances are wrapped in a compile guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            guard_end = "#endif"
        else:
            guard_start = ""
            guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
            'compile_guard_start': guard_start,
            'compile_guard_end': guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        ''' Writes the queued text and closes the file; exceptions from the with-body are not suppressed. '''
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            # Close the handle even if one of the writes above fails.
            self.configuration_file.close()
|
||||
|
||||
###################################################################################################
|
||||
417
python/cutlass_library/rank_k_operation.py
Normal file
417
python/cutlass_library/rank_k_operation.py
Normal file
@@ -0,0 +1,417 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting RankK kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Data structure modeling a Rank K update operation
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class RankKOperation:
|
||||
#
|
||||
def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
|
||||
blas_mode = BlasMode.symmetric):
|
||||
|
||||
self.blas_mode = blas_mode
|
||||
self.operation_kind = OperationKind.RankK
|
||||
self.arch = arch
|
||||
self.tile_description = tile_description
|
||||
self.rank_k_kind = rank_k_kind
|
||||
self.A = A
|
||||
self.C = C
|
||||
self.element_epilogue = element_epilogue
|
||||
self.epilogue_functor = epilogue_functor
|
||||
self.swizzling_functor = swizzling_functor
|
||||
|
||||
#
|
||||
def is_complex(self):
|
||||
complex_operators = [
|
||||
MathOperation.multiply_add_complex,
|
||||
MathOperation.multiply_add_complex_gaussian,
|
||||
MathOperation.multiply_add_complex_fast_f32
|
||||
]
|
||||
return self.tile_description.math_instruction.math_operation in complex_operators
|
||||
return False
|
||||
|
||||
#
|
||||
def is_planar_complex(self):
|
||||
return False
|
||||
|
||||
#
|
||||
def accumulator_type(self):
|
||||
accum = self.tile_description.math_instruction.element_accumulator
|
||||
|
||||
if self.is_complex():
|
||||
return get_complex_from_real(accum)
|
||||
|
||||
return accum
|
||||
|
||||
#
|
||||
def short_math_name(self):
|
||||
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
||||
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
||||
return ShortDataTypeNames[self.accumulator_type()]
|
||||
|
||||
|
||||
#
|
||||
def core_name(self):
|
||||
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
||||
|
||||
inst_shape = ''
|
||||
inst_operation = ''
|
||||
intermediate_type = ''
|
||||
|
||||
math_operations_map = {
|
||||
MathOperation.xor_popc: 'xor',
|
||||
MathOperation.and_popc: 'and'
|
||||
}
|
||||
|
||||
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
||||
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
||||
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
||||
inst_shape += math_op_string
|
||||
|
||||
if self.tile_description.math_instruction.element_a != self.A.element and \
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk'
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
|
||||
|
||||
#
|
||||
def extended_name(self):
|
||||
''' Append data types if they differ from compute type. '''
|
||||
if self.is_complex():
|
||||
extended_name = "${core_name}"
|
||||
else:
|
||||
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${element_c}_${core_name}_${element_a}"
|
||||
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
||||
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
||||
extended_name = "${core_name}_${element_a}"
|
||||
else:
|
||||
extended_name = "${core_name}"
|
||||
|
||||
extended_name = SubstituteTemplate(extended_name, {
|
||||
'element_a': DataTypeNames[self.A.element],
|
||||
'element_c': DataTypeNames[self.C.element],
|
||||
'core_name': self.core_name()
|
||||
})
|
||||
|
||||
return extended_name
|
||||
|
||||
#
|
||||
def layout_name(self):
|
||||
if self.is_complex() or self.is_planar_complex():
|
||||
return "%s" % (
|
||||
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
|
||||
)
|
||||
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
||||
|
||||
#
|
||||
def fill_mode_name(self):
|
||||
return "%s" % (ShortFillModeNames[self.C.fill_mode])
|
||||
|
||||
#
|
||||
def procedural_name(self):
    ''' Return the full procedural name of this kernel.

    The name is assembled from the opcode class, the extended name, the
    threadblock description, the layout of A, the fill mode of C and the
    alignment of operand A.
    '''
    math_instruction = self.tile_description.math_instruction

    # The alignment suffix is taken from operand A. A local holding
    # max(A.alignment, C.alignment) used to be computed here but was never
    # read, so it has been removed; the generated name is unchanged.
    return SubstituteTemplate(
        "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
        {
            'opcode_class': OpcodeClassNames[math_instruction.opcode_class],
            'extended_name': self.extended_name(),
            'threadblock': self.tile_description.procedural_name(),
            'layout': self.layout_name(),
            'fill_mode': self.fill_mode_name(),
            'alignment': "%d" % self.A.alignment,
        }
    )
|
||||
|
||||
#
|
||||
def configuration_name(self):
    ''' Return the configuration name, which is the kernel's full procedural name. '''
    name = self.procedural_name()
    return name
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitRankKUniversalInstance:
    ''' Emits the C++ type alias that instantiates one cutlass::gemm::device::RankK kernel. '''

    def __init__(self):
        # Template used when the operation is not complex-valued.
        self.rank_k_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::RankK<
${element_a}, ${layout_a},
${element_c}, ${layout_c}, ${fill_mode},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${split_k_serial},
${math_operation}
>;
"""
        # Template used for complex-valued operations; it appends the complex
        # transform of A and the BLAS mode as trailing template arguments.
        self.rank_k_complex_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::RankK<
${element_a}, ${layout_a},
${element_c}, ${layout_c}, ${fill_mode},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${split_k_serial},
${math_operation},
${transform_a},
${blas_mode}
>;
"""

    def emit(self, operation):
        ''' Render the instance definition for `operation` and return it as a string. '''
        tile = operation.tile_description
        instruction = tile.math_instruction

        # Per-warp tile extent: threadblock extent divided by the warp count
        # along each of the three dimensions.
        warp_shape = [tile.threadblock_shape[axis] // tile.warp_count[axis] for axis in (0, 1, 2)]

        # The epilogue access is capped at 128 (presumably bits, the unit of
        # DataTypeSize -- confirm) and expressed in elements of C.
        c_size = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_size, 128) / c_size)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'fill_mode': FillModeTag[operation.C.fill_mode],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(instruction.instruction_shape[0]),
            'instruction_shape_n': str(instruction.instruction_shape[1]),
            'instruction_shape_k': str(instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[instruction.math_operation],
            # The next two keys are only referenced by the complex template.
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'blas_mode': BlasModeTag[operation.blas_mode]
        }

        if operation.is_complex():
            chosen_template = self.rank_k_complex_template
        else:
            chosen_template = self.rank_k_template

        return SubstituteTemplate(chosen_template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitter functions for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitRankKConfigurationLibrary:
    ''' Context manager that writes one generated .cu file registering Rank-K kernels.

    Entering opens the output file and writes the header. emit() collects one
    instance definition and one manifest wrapper per operation. Exiting writes
    the collected text followed by the epilogue and closes the file.
    '''

    def __init__(self, operation_path, configuration_name):
        ''' Store the output location (`<operation_path>/<configuration_name>.cu`) and the text templates. '''
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        self.instance_emitter = {
            RankKKind.Universal: EmitRankKUniversalInstance,
        }

        self.rank_k_kind_wrappers = {
            RankKKind.Universal: 'RankKOperation',
        }

        # Opening preprocessor guard placed around WMMA kernels; emit() closes
        # it with "#endif". emit() always read this attribute for WmmaTensorOp
        # operations, but it was never defined, which raised AttributeError.
        self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

        self.instance_template = {
            RankKKind.Universal: """
${compile_guard_start}
manifest.append(new ${rank_k_kind}<
Operation_${operation_name}
>("${operation_name}"));
${compile_guard_end}
"""
        }

        self.header_template = """
/*
Generated by rank_k_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "rank_k_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        ''' Open the output file, write the header and reset the collected state. '''
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        ''' Collect the instance definition and the manifest wrapper for `operation`. '''
        kind = operation.rank_k_kind
        emitter = self.instance_emitter[kind]()

        self.operations.append(operation)
        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA kernels are wrapped in an architecture guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            guard_end = "#endif"
        else:
            guard_start = ""
            guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'rank_k_kind': self.rank_k_kind_wrappers[kind],
            'compile_guard_start': guard_start,
            'compile_guard_end': guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        ''' Write everything collected by emit() and close the file.

        Returns None, so an exception raised inside the `with` body still propagates.
        '''
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            # Close the handle even if one of the writes above fails.
            self.configuration_file.close()
|
||||
|
||||
###################################################################################################
|
||||
430
python/cutlass_library/symm_operation.py
Normal file
430
python/cutlass_library/symm_operation.py
Normal file
@@ -0,0 +1,430 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting Symm kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Data structure modeling a Symm update operation
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class SymmOperation:
    ''' Describes one Symm kernel (symmetric or, per `blas_mode`, Hermitian) and derives its names. '''

    #
    def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \
        epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
        blas_mode = BlasMode.symmetric):
        ''' Record the kernel description; no validation is performed here. '''
        self.blas_mode = blas_mode
        self.operation_kind = OperationKind.Symm
        self.arch = arch
        self.tile_description = tile_description
        self.symm_kind = symm_kind
        # tensor A and B have same data type and layout
        self.A = A
        self.B = B
        self.C = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

    #
    def is_complex(self):
        ''' Return True when the math operation is one of the complex multiply-add variants. '''
        complex_operators = [
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32
        ]
        # An unreachable `return False` used to follow this statement.
        return self.tile_description.math_instruction.math_operation in complex_operators

    #
    def is_planar_complex(self):
        ''' Always False for this operation. '''
        return False

    #
    def accumulator_type(self):
        ''' Return the accumulator element type, converted to its complex form for complex kernels. '''
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum

    #
    def short_math_name(self):
        ''' Return the short accumulator type name, prefixed with "g" for the Gaussian complex operation. '''
        short_name = ShortDataTypeNames[self.accumulator_type()]
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % short_name
        return short_name

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        math_instruction = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and'
        }

        if math_instruction.opcode_class == OpcodeClass.TensorOp or \
            math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:

            # Instruction shape, followed by 'xor' / 'and' for the popc operations.
            inst_shape = "%d%d%d" % tuple(math_instruction.instruction_shape)
            inst_shape += math_operations_map.get(math_instruction.math_operation, '')

            # NOTE(review): this check is assumed to be nested under the
            # tensor-op branch above -- confirm against the original indentation.
            if math_instruction.element_a != self.A.element and \
                math_instruction.element_a != math_instruction.element_accumulator:
                intermediate_type = DataTypeNames[math_instruction.element_a]

        operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm'

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)

    #
    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            extended_name = "${core_name}"
        else:
            if self.C.element != accumulator and self.A.element != accumulator:
                extended_name = "${element_c}_${core_name}_${element_a}"
            elif self.C.element == accumulator and self.A.element != accumulator:
                extended_name = "${core_name}_${element_a}"
            else:
                extended_name = "${core_name}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

        return extended_name

    #
    def layout_name(self):
        ''' Return the short layout tag of operand A; complex kernels also key on A's complex transform. '''
        if self.is_complex() or self.is_planar_complex():
            return "%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
            )
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    #
    def side_mode_name(self):
        ''' Return the short tag for the side mode recorded on operand A. '''
        return "%s" % (ShortSideModeNames[self.A.side_mode])

    #
    def fill_mode_name(self):
        ''' Return the short tag for the fill mode recorded on operand A. '''
        return "%s" % (ShortFillModeNames[self.A.fill_mode])

    #
    def procedural_name(self):
        ''' Return the full procedural name of this kernel.

        The name is assembled from the opcode class, the extended name, the
        threadblock description, the layout, side mode and fill mode of A, and
        the alignment of operand C.
        '''
        threadblock = self.tile_description.procedural_name()

        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        alignment = self.C.alignment

        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}",
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'side_mode': self.side_mode_name(),
                'fill_mode': self.fill_mode_name(),
                'alignment': "%d" % alignment,
            }
        )

    #
    def configuration_name(self):
        ''' Return the configuration name, which is the kernel's full procedural name. '''
        return self.procedural_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitSymmUniversalInstance:
    ''' Emits the C++ type alias that instantiates one cutlass::gemm::device::Symm kernel. '''

    def __init__(self):
        # Template used when the operation is not complex-valued.
        self.symm_template = """
// Symm operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::Symm<
${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
${split_k_serial},
${math_operation}
>;
"""
        # Template used for complex-valued operations; it appends the BLAS
        # mode as a trailing template argument.
        self.symm_complex_template = """
// Symm operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::Symm<
${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
${split_k_serial},
${math_operation},
${blas_mode}
>;
"""

    def emit(self, operation):
        ''' Render the instance definition for `operation` and return it as a string. '''
        tile = operation.tile_description
        instruction = tile.math_instruction

        # Per-warp tile extent: threadblock extent divided by the warp count
        # along each of the three dimensions.
        warp_shape = [tile.threadblock_shape[axis] // tile.warp_count[axis] for axis in (0, 1, 2)]

        # The epilogue access is capped at 128 (presumably bits, the unit of
        # DataTypeSize -- confirm) and expressed in elements of C.
        c_size = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_size, 128) / c_size)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'side_mode': SideModeTag[operation.A.side_mode],
            'fill_mode': FillModeTag[operation.A.fill_mode],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(instruction.instruction_shape[0]),
            'instruction_shape_n': str(instruction.instruction_shape[1]),
            'instruction_shape_k': str(instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[instruction.math_operation],
            # Only referenced by the complex template.
            'blas_mode': BlasModeTag[operation.blas_mode]
        }

        if operation.is_complex():
            chosen_template = self.symm_complex_template
        else:
            chosen_template = self.symm_template

        return SubstituteTemplate(chosen_template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitter functions for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitSymmConfigurationLibrary:
    ''' Context manager that writes one generated .cu file registering Symm kernels.

    Entering opens the output file and writes the header. emit() collects one
    instance definition and one manifest wrapper per operation. Exiting writes
    the collected text followed by the epilogue and closes the file.
    '''

    def __init__(self, operation_path, configuration_name):
        ''' Store the output location (`<operation_path>/<configuration_name>.cu`) and the text templates. '''
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        self.instance_emitter = {
            SymmKind.Universal: EmitSymmUniversalInstance,
        }

        self.symm_kind_wrappers = {
            SymmKind.Universal: 'SymmOperation',
        }

        # Opening preprocessor guard placed around WMMA kernels; emit() closes
        # it with "#endif". emit() always read this attribute for WmmaTensorOp
        # operations, but it was never defined, which raised AttributeError.
        self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

        self.instance_template = {
            SymmKind.Universal: """
${compile_guard_start}
manifest.append(new ${symm_kind}<
Operation_${operation_name}
>("${operation_name}"));
${compile_guard_end}
"""
        }

        self.header_template = """
/*
Generated by symm_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "symm_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        ''' Open the output file, write the header and reset the collected state. '''
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        ''' Collect the instance definition and the manifest wrapper for `operation`. '''
        kind = operation.symm_kind
        emitter = self.instance_emitter[kind]()

        self.operations.append(operation)
        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA kernels are wrapped in an architecture guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            guard_end = "#endif"
        else:
            guard_start = ""
            guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'symm_kind': self.symm_kind_wrappers[kind],
            'compile_guard_start': guard_start,
            'compile_guard_end': guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        ''' Write everything collected by emit() and close the file.

        Returns None, so an exception raised inside the `with` body still propagates.
        '''
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            # Close the handle even if one of the writes above fails.
            self.configuration_file.close()
|
||||
|
||||
###################################################################################################
|
||||
437
python/cutlass_library/trmm_operation.py
Normal file
437
python/cutlass_library/trmm_operation.py
Normal file
@@ -0,0 +1,437 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for emitting Trmm kernels
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os.path
|
||||
import shutil
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from cutlass_library.library import *
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Data structure modeling a TRMM operation
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class TrmmOperation:
    """Data structure modeling a single TRMM operation.

    Stores the operand descriptions, tile configuration and functors of one
    kernel and derives the various names under which that kernel is emitted.
    """

    def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue,
                 epilogue_functor=EpilogueFunctor.LinearCombination,
                 swizzling_functor=SwizzlingFunctor.Identity8):
        """Record the parameters describing the operation.

        Args:
            trmm_kind: Kind of TRMM; used as a key into ``TrmmKindNames``.
            arch: Target architecture number.
            tile_description: Tile configuration; its ``math_instruction`` and
                ``procedural_name()`` are read by the naming methods.
            A: Description of operand A (element, layout, complex transform,
                side mode, fill mode and diagonal type are read from it).
            B: Description of operand B (layout and complex transform are read).
            C: Description of operand C (element and alignment are read).
            element_epilogue: Data type used by the epilogue.
            epilogue_functor: Epilogue functor selector.
            swizzling_functor: Threadblock swizzling functor selector.
        """
        self.operation_kind = OperationKind.Trmm
        self.arch = arch
        self.tile_description = tile_description
        self.trmm_kind = trmm_kind
        self.A = A
        self.B = B
        self.C = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

    def is_complex(self):
        """Return True if the math instruction is one of the complex multiply-adds."""
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    def is_planar_complex(self):
        """Return False: this class never reports a planar-complex operation."""
        # A check against planar-complex TrmmKind values was left disabled here:
        # return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray)
        return False

    def accumulator_type(self):
        """Return the accumulator data type, mapped to its complex form for complex ops."""
        accum = self.tile_description.math_instruction.element_accumulator
        if self.is_complex():
            return get_complex_from_real(accum)
        return accum

    def short_math_name(self):
        """Return the short accumulator type name, prefixed with 'g' for Gaussian complex."""
        short_name = ShortDataTypeNames[self.accumulator_type()]
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % short_name
        return short_name

    def core_name(self):
        """Return the basic operation name, prefixed with the accumulation type.

        For (WMMA) tensor-op instructions the instruction shape, an optional
        'xor'/'and' suffix and, where applicable, the intermediate data type
        are inserted before the TRMM kind name.
        """
        math_inst = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and',
        }

        if math_inst.opcode_class in (OpcodeClass.TensorOp, OpcodeClass.WmmaTensorOp):
            inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)
            inst_shape += math_operations_map.get(math_inst.math_operation, '')

            # The instruction's A type is only spelled out when it differs from
            # both the operand type and the accumulator type.
            # NOTE(review): nesting under the tensor-op branch is assumed; confirm
            # against the original file's indentation.
            if math_inst.element_a != self.A.element and \
                    math_inst.element_a != math_inst.element_accumulator:
                intermediate_type = DataTypeNames[math_inst.element_a]

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind])

    def extended_name(self):
        """Return the core name, with data types appended if they differ from the accumulator type.

        Complex operations always use the bare core name.
        """
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex() or self.A.element == accum:
            template = "${core_name}"
        elif self.C.element != accum:
            template = "${element_c}_${core_name}_${element_a}"
        else:
            template = "${core_name}_${element_a}"

        return SubstituteTemplate(template, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

    def layout_name(self):
        """Return the concatenated short layout names of operands A and B."""
        if self.is_complex() or self.is_planar_complex():
            # Complex kernels encode the complex transform together with the layout.
            return "%s%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
                ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
            )
        return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])

    def side_mode_name(self):
        """Return the short name of operand A's side mode."""
        return "%s" % (ShortSideModeNames[self.A.side_mode])

    def fill_mode_name(self):
        """Return the short name of operand A's fill mode."""
        return "%s" % (ShortFillModeNames[self.A.fill_mode])

    def diag_type_name(self):
        """Return the short name of operand A's diagonal type."""
        return "%s" % (ShortDiagTypeNames[self.A.diag_type])

    def procedural_name(self):
        """Return the full procedural name of the kernel.

        The name combines opcode class, extended name, threadblock tile name,
        layout, side mode, fill mode, diagonal type and the alignment of C.
        """
        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}",
            {
                'opcode_class': OpcodeClassNames[self.tile_description.math_instruction.opcode_class],
                'extended_name': self.extended_name(),
                'threadblock': self.tile_description.procedural_name(),
                'layout': self.layout_name(),
                'side_mode': self.side_mode_name(),
                'fill_mode': self.fill_mode_name(),
                'diag_type': self.diag_type_name(),
                'alignment': "%d" % self.C.alignment,
            }
        )

    def configuration_name(self):
        """Return the configuration name, which is identical to the procedural name."""
        return self.procedural_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emits single instances of a CUTLASS device-wide operator
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
#
|
||||
class EmitTrmmUniversalInstance:
    """Emits the CUTLASS template definition for one TRMM operation."""

    def __init__(self):
        """Set up the C++ templates for real-valued and complex-valued kernels."""
        # Template used when the operation is not complex.
        self.trmm_template = """
// Trmm operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::Trmm<
${element_a}, ${layout_a},
${side_mode}, ${fill_mode}, ${diag_type},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
${split_k_serial},
${math_operation}
>;
"""
        # Same as above with a trailing ${transform_a} template argument.
        self.trmm_complex_template = """
// Trmm operator ${operation_name}
using Operation_${operation_name} =
typename cutlass::gemm::device::Trmm<
${element_a}, ${layout_a},
${side_mode}, ${fill_mode}, ${diag_type},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>,
${swizzling_functor},
${stages},
${align_a},
${align_b},
${split_k_serial},
${math_operation},
${transform_a}
>;
"""

    def emit(self, operation):
        """Return the C++ type definition text for ``operation``.

        Fills the complex template when ``operation.is_complex()`` is true and
        the plain template otherwise.
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction

        # Per-warp tile extent: threadblock extent divided by warp count per dimension.
        warp_shape = [tile.threadblock_shape[dim] // tile.warp_count[dim] for dim in range(3)]

        # Epilogue vector length in elements, capped at 128 bits worth of C elements.
        c_element_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * c_element_bits, 128) / c_element_bits)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'side_mode': SideModeTag[operation.A.side_mode],
            'fill_mode': FillModeTag[operation.A.fill_mode],
            'diag_type': DiagTypeTag[operation.A.diag_type],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_inst.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
        }

        # Expand the three GemmShape triples into <prefix>_m / _n / _k entries.
        shape_groups = (
            ('threadblock_shape', tile.threadblock_shape),
            ('warp_shape', warp_shape),
            ('instruction_shape', math_inst.instruction_shape),
        )
        for prefix, shape in shape_groups:
            for position, axis in enumerate('mnk'):
                values['%s_%s' % (prefix, axis)] = str(shape[position])

        values.update({
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(1),  # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes
            'align_b': str(operation.B.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[math_inst.math_operation],
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
        })

        if operation.is_complex():
            chosen_template = self.trmm_complex_template
        else:
            chosen_template = self.trmm_template

        return SubstituteTemplate(chosen_template, values)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Emitter functions for all targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
class EmitTrmmConfigurationLibrary:
    """Context manager that writes one ``.cu`` file registering TRMM operations.

    On entry the output file is opened and the header written. Each call to
    ``emit`` collects an instance definition and a manifest wrapper. On exit
    the definitions, the ``initialize_<configuration_name>`` function and the
    closing text are written and the file is closed.
    """

    # Preprocessor guard opened around WMMA tensor-op instances; `emit` closes
    # it with "#endif". `emit` has always read this attribute, but it was never
    # defined, so WMMA operations raised AttributeError.
    wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

    def __init__(self, operation_path, configuration_name):
        """Prepare the output path and the text templates.

        Args:
            operation_path: Directory in which the ``.cu`` file is created.
            configuration_name: Base name of the file and suffix of the
                generated ``initialize_...`` function.
        """
        self.configuration_name = configuration_name
        # Backslashes are normalized to forward slashes in the resulting path.
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        # Emitter class used for each TRMM kind.
        self.instance_emitter = {
            TrmmKind.Universal: EmitTrmmUniversalInstance,
        }

        # Name of the C++ wrapper template used for each TRMM kind.
        self.trmm_kind_wrappers = {
            TrmmKind.Universal: 'TrmmOperation',
        }

        # Manifest registration snippet for each TRMM kind.
        self.instance_template = {
            TrmmKind.Universal: """
${compile_guard_start}
manifest.append(new ${trmm_kind}<
Operation_${operation_name}
>("${operation_name}"));
${compile_guard_end}
"""
        }

        self.header_template = """
/*
Generated by trmm_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "trmm_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        """Open the output file, write the header and reset the collected state."""
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        """Collect the instance definition and manifest wrapper for ``operation``."""
        emitter = self.instance_emitter[operation.trmm_kind]()

        self.operations.append(operation)

        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA tensor-op instances are wrapped in a preprocessor guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            guard_end = "#endif"
        else:
            guard_start = ""
            guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind],
            'compile_guard_start': guard_start,
            'compile_guard_end': guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write everything collected so far and close the file.

        Returns None, so an exception raised inside the ``with`` body still
        propagates. The file is closed even if writing fails.
        """
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            self.configuration_file.close()
|
||||
|
||||
###################################################################################################
|
||||
Reference in New Issue
Block a user