mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
Rename python/cutlass to python/cutlass_cppgen (#2652)
This commit is contained in:
34
python/cutlass_cppgen/backend/evt/__init__.py
Normal file
34
python/cutlass_cppgen/backend/evt/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
||||
38
python/cutlass_cppgen/backend/evt/backend/__init__.py
Normal file
38
python/cutlass_cppgen/backend/evt/backend/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes
|
||||
159
python/cutlass_cppgen/backend/evt/backend/emitter_base.py
Normal file
159
python/cutlass_cppgen/backend/evt/backend/emitter_base.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base class for Epilogue Visitor Emitter
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
|
||||
|
||||
class FusionCallbacks:
    """
    Emitter for the C++ fusion-callback type declarations of an epilogue
    visitor tree (EVT) held in a DAG IR.

    The emitted text is a sequence of C++ `using` alias declarations; the
    exact strings (including interior indentation) are part of the generated
    source, so statement order here is significant.
    """

    def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
        """
        Emit the EVT fusion callbacks

        :param dag_ir: the DAG IR holding the epilogue visitor
        :param cc: compute capability
        :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
            For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
            so that their shared memory can be explicitly reused
            For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
        """
        self.dag_ir = dag_ir
        self.emit_CD = emit_CD
        self.cc = cc
        # cc >= 90 architectures reuse the Sm90-prefixed EVT C++ templates.
        self.evt_cc = 90 if cc >= 90 else cc
        # Pre-Sm90 epilogues live in the `threadblock` namespace; Sm90+ in `fusion`.
        if self.cc < 90:
            self.namespace = "threadblock"
        else:
            self.namespace = "fusion"

    #
    # Helper functions
    #

    def get_visitor_name(self, node: str):
        """
        Get the visitor name

        Non-topological nodes with at least one input are wrapped in an
        `EVT`-prefixed connector type; all others use their own camel name.
        """
        meta = self.dag_ir.get_node_meta(node)
        if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
            return f"EVT{meta.name_camel}"
        else:
            return meta.name_camel

    def emit(self):
        """
        Emit the full callback declaration string.

        :return: (epilogue_str, callback_name) — the concatenated C++ type
            declarations and the name of the top-level callback type.
        """
        node_metas = self.dag_ir.node_metas_topological_order()
        epilogue_str = ""
        # Step 1: emit individual node type decl
        # emit the EVT & DAG connector
        for meta in node_metas:
            if not meta.disabled:
                epilogue_str += self.emit_node(meta)
            # When C & D are owned by the collective API, skip D's connector.
            if not self.emit_CD and meta.name == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                epilogue_str += self.emit_dag(meta)
            else:
                epilogue_str += self.emit_evt(meta)

        # Step 2: post-processing & get callback name
        if not self.emit_CD:
            if not self.dag_ir.has_node("C"):
                epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
            output_node = self.dag_ir.get_all_inputs("D")[0]
            # The callback is the src of node D
            callback_name = self.get_visitor_name(output_node)
        else:
            # The callback is the last node in the topological order
            callback_name = self.get_visitor_name(node_metas[-1].name)
        return epilogue_str, callback_name

    def emit_evt(self, node):
        """
        Emit the EVT connector type for `node`, wiring the node to its
        (already-declared) child visitors. Pure input nodes produce nothing.
        """
        if self.dag_ir.in_degree(node.name) == 0:
            return ""

        evt_tmp = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
    {node.name_camel},
"""
        sorted_children = self.dag_ir.get_all_inputs(node.name)
        evt_node_strs = [f"    {self.get_visitor_name(child_name)}" for child_name in sorted_children]
        evt_tmp += ",\n".join(evt_node_strs) + ">;\n"

        return evt_tmp

    def emit_dag(self, node):
        """
        Emit a TopologicalVisitor type for a DAG (TopoVisitor) node: an edge
        tuple encoding each subgraph node's inputs by topological index,
        followed by the node type list.
        """
        subgraph = node.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Emit the Edge Tuple
        edge_tuples = "cute::tuple<\n"
        for n in subgraph_nodes[:-1]:
            in_edges = subgraph.in_edges(n)
            edge_weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
            # Order children by edge weight so operand order is deterministic.
            sorted_children = [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
            edge_tuple = "    cute::seq<"
            edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
            edge_tuple += ", ".join(edge_str) + ">,\n"

            edge_tuples += edge_tuple
        edge_tuples += "    >"

        # Emit the node list
        dag_nodes = ""
        dag_node_strs = []
        for n in subgraph_nodes[:-1]:
            n_meta = subgraph.get_node_meta(n)
            # Disabled nodes are referenced via their visitor name instead of
            # being re-declared inline.
            if n_meta.disabled:
                dag_node_strs.append(f"    {self.get_visitor_name(n)}")
            else:
                dag_node_strs.append(f"    {n_meta.name_camel}")
        dag_nodes = ",\n".join(dag_node_strs)

        return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
    {DataTypeTag[node.subgraph.element_compute]},
    {edge_tuples},
    {dag_nodes}
>;
"""

    def emit_node(self, node):
        """
        Emit the type declaration(s) of a single node. A TopoVisitorNode
        recursively emits every enabled node of its subgraph; a plain node
        defers to its underlying implementation's `type_decl`.
        """
        if isinstance(node, TopoVisitorNode):
            emission = ""
            # NOTE: the loop variable deliberately rebinds `node` to each
            # subgraph member; the outer `node` is not needed afterwards.
            for node in node.subgraph.node_metas_topological_order():
                if not node.disabled:
                    emission += self.emit_node(node)
            return emission
        else:
            return node.underlying_impl.type_decl
|
||||
116
python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py
Normal file
116
python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm100 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
|
||||
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
|
||||
|
||||
|
||||
class Sm100CollectiveEpilogue:
    """
    Emitter for the Sm100 collective epilogue descriptor.

    Wraps the fusion callbacks of an epilogue visitor tree and produces the
    C++ `Sm100EpilogueDescriptor` declaration followed by the callback
    declarations.
    """

    def __init__(self, tile_description,
                 kernel_schedule,
                 epilogue_schedule,
                 element_accumulator,
                 element_d,
                 fusion_callbacks) -> None:
        """
        :param tile_description: tile description of the operation
        :param kernel_schedule: kernel schedule used to derive the CTA tile shape
        :param epilogue_schedule: epilogue schedule tag to emit
        :param element_accumulator: accumulator data type
        :param element_d: data type of output tensor D
        :param fusion_callbacks: FusionCallbacks instance holding the EVT DAG IR
        """
        self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule)
        self.element_accumulator = element_accumulator
        # Tensor C is optional: fall back to void when absent from the DAG.
        if fusion_callbacks.dag_ir.has_node("C"):
            self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
        else:
            self.element_c = DataType.void
        self.element_d = element_d
        self.schedule = epilogue_schedule
        self.fusion_callbacks = fusion_callbacks
        self.opclass = tile_description.math_instruction.opcode_class

    @property
    def CtaTileMNK(self) -> str:
        """
        The threadblock shape
        """
        return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"

    @property
    def EpilogueTileType(self) -> str:
        """
        The epilogue tile type
        """
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """
        The C++ tag of the epilogue schedule.
        """
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """
        Emit the epilogue descriptor and fusion-callback declarations.

        :return: (callback_name, decl_str) — the top-level callback type name
            and the full C++ declaration text.
        """
        # Note: a previous revision constructed an unused TupleEmitter here;
        # it was dead code and has been removed.
        stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl
        # When C is absent, its stride mirrors D's.
        stride_C_str = stride_D_str
        if self.fusion_callbacks.dag_ir.has_node("C"):
            stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl

        callback_decl, callback_name = self.fusion_callbacks.emit()
        return callback_name, f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
    {OpcodeClassTag[self.opclass]},
    {self.CtaTileMNK}, {self.EpilogueTileType},
    {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
    {self.Schedule}, {stride_C_str}, {stride_D_str},
    false /* IsPerColScaleSupported */,
    false /* IsBlockScaleSupported */
>;
{callback_decl}
"""
|
||||
|
||||
|
||||
class Sm100Emitter:
    """
    Top-level emitter for the Sm100 epilogue visitor: builds the fusion
    callbacks from the EVT graph and delegates emission to the collective
    epilogue.
    """

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        :param operation: the GEMM operation whose tile description drives emission
        :param graph: the DAG IR of the epilogue visitor tree
        """
        # C & D are owned by the Sm100 collective API, hence emit_CD=False.
        callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
        tile = operation.tile_description

        self.collective_epilogue = Sm100CollectiveEpilogue(
            tile_description=tile,
            kernel_schedule=tile.kernel_schedule,
            epilogue_schedule=tile.epilogue_schedule,
            element_accumulator=tile.math_instruction.element_accumulator,
            element_d=callbacks.dag_ir.get_node_meta("D").element,
            fusion_callbacks=callbacks
        )

    def emit(self):
        """Emit the (callback_name, declaration) pair for this epilogue."""
        return self.collective_epilogue.emit()
|
||||
134
python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py
Normal file
134
python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from pycute import product
|
||||
|
||||
from cutlass_library import DataTypeSize, DataTypeTag
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
|
||||
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
|
||||
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyleTag
|
||||
|
||||
|
||||
# Most Sm100 EVT node implementations are identical to their Sm90
# counterparts, so they are re-exported here under the Sm100 prefix.
# Only AuxLoad/AuxStore (defined below) need Sm100-specific descriptors.
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
|
||||
|
||||
|
||||
class Sm100AuxLoadImpl(AuxLoadImpl):
    """
    Sm100 auxiliary-load node: declares an Sm100AuxLoadDescriptor and the
    corresponding Sm90AuxLoad visitor type.
    """

    @property
    def descriptor(self) -> str:
        """
        Name of the descriptor type for this aux load.
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            decl = self.decl_descriptor()
            decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
            self._type_decl = decl
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        num_bytes = DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8
        return (num_bytes, 128)
|
||||
|
||||
|
||||
class Sm100AuxStoreImpl(AuxStoreImpl):
    """
    Sm100 auxiliary-store node: declares an Sm100AuxStoreDescriptor and the
    corresponding Sm90AuxStore visitor type.
    """

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Store
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        # Cached after first emission so repeated queries are cheap.
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
|
||||
47
python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py
Normal file
47
python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm80 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
|
||||
|
||||
class Sm80Emitter:
    """
    Top-level emitter for the Sm80 epilogue visitor.
    """

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        :param operation: the GEMM operation being emitted (unused beyond the signature)
        :param graph: the DAG IR of the epilogue visitor tree
        """
        self.fusion_callbacks = FusionCallbacks(graph, cc=80)

    def emit(self):
        """Return (callback_name, callback_decl) for this epilogue."""
        decl, name = self.fusion_callbacks.emit()
        return name, decl
|
||||
258
python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py
Normal file
258
python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py
Normal file
@@ -0,0 +1,258 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_library import DataTypeSize, DataTypeTag
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
ColumnBroadcastImpl,
|
||||
LoadNode,
|
||||
LoadSrcImpl,
|
||||
RowBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
# Compute Node
|
||||
ComputeImpl,
|
||||
# Store Node
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl
|
||||
)
|
||||
|
||||
from cutlass_cppgen.backend.library import (
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
)
|
||||
|
||||
|
||||
class Sm80AccumulatorImpl(AccumulatorImpl):
    """
    Sm80 accumulator-fetch node: emits a VisitorAccFetch alias.
    """

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80AuxLoadImpl(AuxLoadImpl):
    """
    Sm80 auxiliary-load node: emits a VisitorAuxLoad alias.
    """

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    # Loading the source tensor C is emitted identically to a generic aux
    # load on Sm80, so this is a pure alias subclass.
    pass
|
||||
|
||||
|
||||
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    """
    Sm80 scalar-broadcast node: emits a VisitorScalarBroadcast alias.
    """

    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        # Defaults: broadcast a single scalar, combined multiplicatively.
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80RowBroadcastImpl(RowBroadcastImpl):
    """
    Sm80 row-broadcast node: emits a VisitorRowBroadcast alias.
    """

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
    """
    Sm80 column-broadcast node: emits a VisitorColBroadcast alias.
    """

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80ComputeImpl(ComputeImpl):
    """
    Sm80 compute node: emits a VisitorCompute alias for the node's functional op.
    """

    @property
    def type_decl(self):
        """
        Lazily build and cache the C++ type declaration string.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80AuxStoreImpl(AuxStoreImpl):
    """Sm80 emitter for an auxiliary-output store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80StoreDImpl(Sm80AuxStoreImpl):
    """Store node for output D on Sm80: inherits the aux-store emission unchanged."""
    pass
|
||||
|
||||
|
||||
class Sm80ColumnReductionImpl(ColumnReductionImpl):
    """Sm80 emitter for a column-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80RowReductionImpl(RowReductionImpl):
    """Sm80 emitter for a row-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm80ScalarReductionImpl(ScalarReductionImpl):
    """Sm80 emitter for a scalar-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
98
python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py
Normal file
98
python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm90 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag, EpilogueScheduleTag
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class CollectiveEpilogue:
    """Renders the EpilogueDescriptor declaration for an Sm90 collective epilogue.

    Holds the kernel tile shape, the C/D element types, the epilogue schedule,
    and the fusion callbacks emitted from the EVT graph.
    """

    def __init__(self, tile_description,
                 schedule,
                 element_c,
                 element_d,
                 fusion_callbacks) -> None:
        self.cta_tile_mnk = tile_description.threadblock_shape
        self.schedule = schedule
        self.element_c = element_c
        self.element_d = element_d
        self.fusion_callbacks = fusion_callbacks

    @property
    def CtaTileMNK(self) -> str:
        """CTA tile shape rendered as a ``cute::Shape`` type string."""
        m, n, k = self.cta_tile_mnk[0], self.cta_tile_mnk[1], self.cta_tile_mnk[2]
        return f"cute::Shape<_{m}, _{n}, _{k}>"

    @property
    def EpilogueTileType(self) -> str:
        """Epilogue tile type: let the collective pick it automatically."""
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """C++ tag string for the configured epilogue schedule."""
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """Return ``(callback_name, declaration)`` for this collective epilogue."""
        callback_decl, callback_name = self.fusion_callbacks.emit()
        declaration = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    {self.CtaTileMNK}, {self.EpilogueTileType},
    {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
    {self.Schedule}
>;
{callback_decl}
"""
        return callback_name, declaration
|
||||
|
||||
|
||||
class Sm90Emitter:
    """Lowers an EVT graph into the Sm90 collective-epilogue declaration."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        # On Sm90 the collective handles C/D itself, hence emit_CD=False.
        callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
        tile = operation.tile_description
        self.collective_epilogue = CollectiveEpilogue(
            tile_description=tile,
            schedule=tile.epilogue_schedule,
            element_c=operation.C.element,
            # NOTE(review): D reuses C's element type here — presumably tied to
            # the C/D smem-reuse constraint on Sm90; confirm this is intended.
            element_d=operation.C.element,
            fusion_callbacks=callbacks,
        )

    def emit(self):
        """Emit and return the collective epilogue's (name, declaration) pair."""
        return self.collective_epilogue.emit()
|
||||
329
python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py
Normal file
329
python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from pycute import product
|
||||
|
||||
from cutlass_library import DataTypeSize, DataTypeTag
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
ColumnBroadcastImpl,
|
||||
LoadNode,
|
||||
LoadSrcImpl,
|
||||
RowBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
# Compute Node
|
||||
ComputeImpl,
|
||||
ComputeNode,
|
||||
# Store Node
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl,
|
||||
StoreNode,
|
||||
StoreDImpl,
|
||||
)
|
||||
from cutlass_cppgen.backend.library import (
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
)
|
||||
|
||||
|
||||
class Sm90AccumulatorImpl(AccumulatorImpl):
    """Sm90 emitter for the accumulator-fetch load node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90LoadSrcImpl(LoadSrcImpl):
    """Sm90 emitter for loading the source tensor C."""

    @property
    def type_decl(self):
        """Build (once) the ElementC/StrideC aliases plus the SrcFetch declaration."""
        if self._type_decl is None:
            self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90AuxLoadImpl(AuxLoadImpl):
    """Sm90 emitter for an auxiliary-tensor load node."""

    @property
    def descriptor(self) -> str:
        """Name of the AuxLoadDescriptor type generated for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the descriptor type."""
        return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"

    @property
    def type_decl(self):
        """Build (once) the descriptor plus the Sm90AuxLoad declaration."""
        if self._type_decl is None:
            decl = self.decl_descriptor()
            decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
            self._type_decl = decl
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return ``(bytes, alignment)`` of smem used by this node's load pipeline."""
        smem_bits = DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn)
        return (smem_bits // 8, 128)
|
||||
|
||||
|
||||
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
    """Sm90 emitter for a scalar-broadcast load node."""

    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        # Default template parameters passed to Sm90ScalarBroadcast.
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90RowBroadcastImpl(RowBroadcastImpl):
    """Sm90 emitter for a row-broadcast load node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
    """Sm90 emitter for a column-broadcast load node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90ComputeImpl(ComputeImpl):
    """Sm90 emitter for an elementwise compute node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
    {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90AuxStoreImpl(AuxStoreImpl):
    """Sm90 emitter for an auxiliary-output store node."""

    @property
    def descriptor(self) -> str:
        """Name of the AuxStoreDescriptor type generated for this node."""
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """Return the C++ declaration of the descriptor type."""
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """Build (once) the descriptor plus the Sm90AuxStore declaration."""
        if self._type_decl is None:
            decl = self.decl_descriptor()
            decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
            self._type_decl = decl
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """Return ``(bytes, alignment)`` of smem used by this node's store pipeline."""
        smem_bits = DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn)
        return (smem_bits // 8, 128)
|
||||
|
||||
|
||||
class Sm90StoreDImpl(StoreDImpl):
    """Sm90 store node for the kernel output D."""

    @property
    def type_decl(self):
        """Emit the ElementD/StrideD aliases consumed by the collective epilogue."""
        declaration = f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
        return declaration
|
||||
|
||||
|
||||
class Sm90ColumnReductionImpl(ColumnReductionImpl):
    """Sm90 emitter for a column-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90RowReductionImpl(RowReductionImpl):
    """Sm90 emitter for a row-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
    typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
    {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
|
||||
|
||||
class Sm90ScalarReductionImpl(ScalarReductionImpl):
    """Sm90 emitter for a scalar-reduction store node."""

    @property
    def type_decl(self):
        """Build (once) and return the C++ using-declaration for this node."""
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
    {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
    {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
        return self._type_decl
|
||||
168
python/cutlass_cppgen/backend/evt/epilogue.py
Normal file
168
python/cutlass_cppgen/backend/evt/epilogue.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import DataType
|
||||
import numpy as np
|
||||
|
||||
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass_cppgen.backend.evt.backend
|
||||
from cutlass_cppgen.backend.frontend import TensorFrontend
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor(EpilogueFunctorBase):
    """
    Apply an epilogue functor described by the epilogue EVT.

    Builds a ctypes argument structure (``self.epilogue_type``) whose
    constructor converts user-supplied tensors into device pointers for the
    generated epilogue visitor.

    :param cc: compute capability
    :param visitor: the user-provided visitor frontend (supplies ``dag_ir``,
        the epilogue thread type, and — on Sm90/Sm100 — the C/D argument types)
    :param element_compute: data type used for epilogue computation
    """
    def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
        # Type of Emitter based on CC (e.g. Sm80Emitter / Sm90Emitter)
        self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")

        # Visitor Types
        self.visitor = visitor
        self.graph = visitor.dag_ir

        # Data types
        self.element_epilogue = element_compute # element type used for compute
        # Output element type comes from the 'D' node of the EVT graph
        self.element_output = self.graph.get_node_meta('D').underlying_impl.element

        # Epilogue Thread Type
        epilogue_thread_type = self.visitor.epilogue_thread_type
        if cc_map[cc] in [90, 100]:
            # Sm90/Sm100 carry dedicated ctypes argument types for C and D
            self.arg_c_type = self.visitor.arg_c_type
            self.arg_d_type = self.visitor.arg_d_type
        # Names captured by the closures of the nested _Arguments class below
        output_names = self.visitor.return_names
        reduction_names = self.visitor.reduction_names

        # Epilogue stages specialized for sm80 kernel
        if cc == 80:
            if hasattr(self.visitor, "epilogue_stages"):
                self.epilogue_stages = self.visitor.epilogue_stages
                assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"

        # Epilogue Argument Type (closes over cc, epilogue_thread_type,
        # output_names, and reduction_names from this __init__)
        class _Arguments(ctypes.Structure):
            """
            Concepts:
            class _EpilogueArguments(ctypes.Structure):
                _fields_ = [
                    ("epilogue", _Arguments), <- this class
                    ("ptr_C", ctypes.c_void_p),
                    ("stride_C", StrideBatched_),
                    ("ptr_D", ctypes.c_void_p),
                    ("stride_D", StrideBatched_)
                ]
            """
            _fields_ = [
                ("output_op", epilogue_thread_type)
            ]

            def __init__(self, kwargs: dict) -> None:
                # The user-input kwargs is a dict of (name: tensors)
                # We first convert all of them to device pointers
                ptr_kwargs = {}
                for key in kwargs.keys():
                    # Reduction outputs are not treated as plain outputs here
                    is_output = key in output_names and key not in reduction_names
                    ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
                # Initialize the thread arguments
                self.output_op = epilogue_thread_type(ptr_kwargs)

            def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
                """
                Helper function for extracting the device pointer of a tensor.
                Returns 0 for C/D on Sm90/Sm100 (handled by the collective),
                the raw value for float scalars, and an integer device pointer
                otherwise. Raises ValueError if the tensor is missing.
                """
                # Skip the special tensors: on Sm90/Sm100 C and D are passed
                # through dedicated arguments, not through the visitor
                if cc in [90, 100]:
                    if tensor_name in ["C", "D"]:
                        return 0
                if tensor_name not in kwargs.keys():
                    raise ValueError(f"Tensor {tensor_name} is not provided.")
                tensor = kwargs[tensor_name]

                # For float scalar constant, directly return the value
                if isinstance(tensor, float):
                    return tensor

                # The tensor frontend returns a device buffer for np.ndarray
                # and device ptr for other frontends
                buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
                if is_numpy_tensor(tensor):
                    # Remember the host tensor for later synchronization
                    setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
                    setattr(self, f"{tensor_name}_host", tensor)
                    return int(buffer_or_ptr.ptr)
                else:
                    return int(buffer_or_ptr)

            def sync(self):
                """
                Synchronize the results from device to host
                (only numpy-backed outputs recorded in get_tensor_ptr).
                """
                for name in output_names:
                    if hasattr(self, f"{name}_host"):
                        host_tensor = getattr(self, f"{name}_host")
                        tensor_ptr = getattr(self, f"{name}_buffer").ptr
                        (err,) = cuda.cuMemcpyDtoH(
                            host_tensor,
                            tensor_ptr,
                            host_tensor.size * host_tensor.itemsize,
                        )
                        if err != cuda.CUresult.CUDA_SUCCESS:
                            raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _Arguments

    def emit(self, operation):
        """
        Emit the C++ code for the epilogue of the given operation.
        """
        emitter = self.emit_cls(operation, self.graph)
        return emitter.emit()

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size in bytes (delegates to the visitor).
        """
        return self.visitor.get_smem_size(tile_description)
|
||||
33
python/cutlass_cppgen/backend/evt/frontend/__init__.py
Normal file
33
python/cutlass_cppgen/backend/evt/frontend/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend
|
||||
272
python/cutlass_cppgen/backend/evt/frontend/frontend_base.py
Normal file
272
python/cutlass_cppgen/backend/evt/frontend/frontend_base.py
Normal file
@@ -0,0 +1,272 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base class for Python EVT Frontend
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from cutlass_library import DataType
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
ComputeNode,
|
||||
DAGIR,
|
||||
LayoutNode,
|
||||
LoadNode,
|
||||
StoreNode,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.passes import (
|
||||
EVTGraphDrawer,
|
||||
EVTPassManager,
|
||||
GetSmemSize,
|
||||
PassDAG2Tree,
|
||||
PassGetArgumentType,
|
||||
PassGetImpl,
|
||||
PassFixElementD,
|
||||
PassLayoutManipulateElimination,
|
||||
PassPreprocessRed,
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
|
||||
from cutlass_cppgen.utils.datatypes import library_type
|
||||
|
||||
|
||||
class EVTFrontendBase:
    """
    Base class for Python EVT (Epilogue Visitor Tree) frontends.

    A frontend translates user input into the DAG IR. Subclasses must
    implement :meth:`parse`; :meth:`trace` then verifies the IR and runs
    the compiler pass pipeline over it.
    """

    # Layout-manipulation ops recognized by the frontend, keyed by name
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, cc, element_compute=DataType.f32, additional_passes=None, **kwargs) -> None:
        """
        :param cc: compute capability of the target device
        :param element_compute: data type used for intermediate computation
        :param additional_passes: optional iterable of extra passes appended
            to the default pass pipeline (default: none)
        """
        self.cc = cc
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        # Counters used to generate unique names for autogenerated nodes
        self.compute_cnt = 0
        self.layout_cnt = 0
        self.imm_cnt = 0

        # Fix: avoid a mutable default argument ([]) for additional_passes,
        # which would be shared across all instances
        if additional_passes is None:
            additional_passes = []

        self.pass_manager = EVTPassManager(
            self.dag_ir,
            [
                PassPreprocessRed,
                PassGetArgumentType,
                PassShapeTypePropagation,
                PassLayoutManipulateElimination,
                PassGetImpl,
                PassDAG2Tree,
                PassFixElementD
            ] + list(additional_passes))

        # SM80 epilogues use a single stage; other targets leave it unset
        if self.cc == 80:
            self._epilogue_stages = 1
        else:
            self._epilogue_stages = None

    @property
    def epilogue_stages(self):
        """Number of epilogue stages (``None`` if unset/not applicable)."""
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        """
        Translate the user input into the DAG IR. Must be overridden.
        """
        raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class")

    def trace(self, *args, **kwargs):
        """
        Parse the input, verify the resulting DAG IR, and run the pass pipeline.
        """
        # Parse the input
        self.parse(*args, **kwargs)

        # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
        if (self.cc >= 90):
            if (self.dag_ir.out_degree("D") != 0):
                raise RuntimeError(
                    f"On SM90 or higher, D is expected to be a output node with 0 users to "
                    f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        if cc_map[self.cc] in [90, 100]:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #

    def add_node(self, node):
        """Add a node to the DAG IR."""
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        """Add a weighted edge src -> tgt to the DAG IR."""
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as output
        :raises ValueError: if the node is not a StoreNode
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                f"Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    # Add node with specific type

    def add_load_node(self, name, example):
        """
        Add a Load node to DAG IR
        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        """
        if name is None:
            raise ValueError(f"Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logics for accumulator: normalize a rank-2 example to rank-3
        # by prepending a unit batch dimension
        if name == "accum":
            if load_node.tensor.rank == 2:
                new_shape = tuple([1, ] + list(load_node.tensor.shape))
                load_node.tensor.broadcast(new_shape)
            elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
                raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float, int]):
        """
        Add an immediate scalar value to DAG IR
        :param value: the value of the immediate scalar
        :type value: float
        :return: the name of the created load node
        """
        try:
            value = float(value)
        # Fix: catch only conversion failures (was a bare `except:`) and
        # chain the original exception for debuggability
        except (TypeError, ValueError) as e:
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from e

        name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
        self.imm_cnt += 1
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        """Add a store node with the given name to the DAG IR."""
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualization The DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the dag ir with svg file
        :param name: the name of the graph
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            # Avoid shadowing the `name` parameter with the loop variable
            for graph_name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{graph_name}.svg")
        # Fix: narrow the bare `except:` (which also caught KeyboardInterrupt)
        # and chain the underlying error
        except Exception as e:
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            ) from e

    #
    # Get shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size
|
||||
194
python/cutlass_cppgen/backend/evt/frontend/python_ast.py
Normal file
194
python/cutlass_cppgen/backend/evt/frontend/python_ast.py
Normal file
@@ -0,0 +1,194 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Python AST frontend that parses input into DAG IR
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
||||
from cutlass_cppgen.backend.library import FunctionalOp
|
||||
|
||||
|
||||
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
    """
    Frontend that parses the Python source of ``self.__call__`` into the
    DAG IR by walking its abstract syntax tree.
    """
    def __init__(self, cc, element_compute=DataType.f32, **kwargs):
        super().__init__(cc, element_compute, **kwargs)
        # Flags
        # If this state is True, visit_Constant returns values without creating imm node
        self.no_imm = False
        self.visiting_return = False

    def parse(self, example_inputs):
        """
        Parse the ``__call__`` method of this object into the DAG IR.

        :param example_inputs: mapping from argument/result name to example tensor
        """
        self.example_inputs = example_inputs
        self.source = textwrap.dedent(inspect.getsource(self.__call__))
        self.ast = ast.parse(self.source)
        self.visit(self.ast)

    #
    # Helper functions
    #
    @staticmethod
    def ast_op_to_bindings(op):
        """
        Map an AST operator type or a function name to its EVT binding.
        :raises KeyError: if the op is not supported
        """
        mapping = {
            ast.Add: FunctionalOp.Plus,
            ast.Sub: FunctionalOp.Minus,
            ast.Mult: FunctionalOp.Multiplies,
            ast.Div: FunctionalOp.Divides,
            "maximum": FunctionalOp.Maximum,
            "minimum": FunctionalOp.Minimum,
            "identity": identity.binding_type,
            "relu": relu.binding_type,
            "tanh": tanh.binding_type,
            "sigmoid": sigmoid.binding_type,
            "silu": silu.binding_type,
            "hardswish": hardswish.binding_type,
            "gelu": gelu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
            "exp": FunctionalOp.Exp
        }
        return mapping[op]

    #
    # Visiting different node types
    #

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Visit args and register load nodes
        for arg in node.args.args:
            self.visit(arg)
        for expr in node.body:
            self.visit(expr)

    def visit_arg(self, node: ast.arg):
        # Name of the argument
        name = node.arg
        try:
            example_tensor = self.example_inputs[name]
        # Fix: was a bare `except:`; catch the actual lookup failures and
        # chain the original exception
        except (KeyError, TypeError) as e:
            raise RuntimeError(f"Example input for {name} is not provided.") from e

        self.add_load_node(name, example_tensor)

    def visit_Name(self, node: ast.Name):
        return node.id

    def visit_Constant(self, node: ast.Constant):
        # With no_imm set, return the raw value (used for layout-fn kwargs);
        # otherwise register an immediate load node and return its name
        if self.no_imm:
            return node.value
        else:
            name = self.add_imm(node.value)
            return name

    def visit_Tuple(self, node: ast.Tuple):
        results = []
        for elt in node.elts:
            results.append(self.visit(elt))
        return tuple(results)

    def visit_keyword(self, node: ast.keyword):
        return {node.arg: self.visit(node.value)}

    def visit_BinOp(self, node: ast.BinOp):
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = self.ast_op_to_bindings(type(node.op))
        name = self.add_compute_node(op)

        # Add edges
        # The edge weights are used to sort the input args
        self.add_edge(lhs, name, weight=0)
        self.add_edge(rhs, name, weight=1)
        return name

    def visit_Assign(self, node: ast.Assign):
        # Fix: annotation was incorrectly `ast.BinOp`
        target = self.visit(node.targets[0])
        value = self.visit(node.value)
        # Create the assign node
        self.add_store_node(target)

        # Add edges
        self.add_edge(value, target)
        return target

    def visit_Call(self, node: ast.Call):
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]

        if func in self.layout_fns.keys():
            # Parse kwargs
            # By default, visiting imm automatically creates a load node
            # However, in function call, keyword args are used to set
            # specific function attributes such as indices for permute
            # So no_imm is set to True temporarily
            self.no_imm = True
            kwargs = {}
            for kw in node.keywords:
                kwargs.update(self.visit(kw))
            self.no_imm = False
            op = self.layout_fns[func]
            name = self.add_layout_node(op, kwargs)
        else:
            op = self.ast_op_to_bindings(func)
            name = self.add_compute_node(op)

        # Add edges
        for idx, arg in enumerate(args):
            self.add_edge(arg, name, weight=idx)
        return name

    def visit_Return(self, node: ast.Return):
        self.visiting_return = True
        results = self.visit(node.value)
        self.visiting_return = False
        self.return_names = results
        if not isinstance(results, tuple):
            results = (results,)
        for rst in results:
            try:
                example_tensor = self.example_inputs[rst]
            # Fix: was a bare `except:`; catch the actual lookup failures and
            # chain the original exception
            except (KeyError, TypeError) as e:
                raise RuntimeError(f"Example input for {rst} is not provided.") from e
            self.set_store_tensor(rst, example_tensor)
            self.mark_output(rst)
|
||||
53
python/cutlass_cppgen/backend/evt/ir/__init__.py
Normal file
53
python/cutlass_cppgen/backend/evt/ir/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
|
||||
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
||||
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
|
||||
from cutlass_cppgen.backend.evt.ir.load_nodes import (
|
||||
LoadNode,
|
||||
AccumulatorImpl,
|
||||
LoadSrcImpl,
|
||||
AuxLoadImpl,
|
||||
RowBroadcastImpl,
|
||||
ColumnBroadcastImpl,
|
||||
ScalarBroadcastImpl
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.ir.store_nodes import (
|
||||
StoreNode,
|
||||
StoreDImpl,
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl
|
||||
)
|
||||
91
python/cutlass_cppgen/backend/evt/ir/compute_nodes.py
Normal file
91
python/cutlass_cppgen/backend/evt/ir/compute_nodes.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Python registration for compute nodes in EVT
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle
|
||||
|
||||
|
||||
class ComputeImplBase(ImplBase):
    """
    Common base class shared by all compute-node implementations.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
|
||||
|
||||
|
||||
class ComputeImpl(ComputeImplBase):
    """
    Concrete implementation backing a ComputeNode.

    Mirrors the node's functional op along with its element types and
    rounding style onto the implementation object.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        # Copy the configuration off the node
        self.fn = node.fn
        self.element_output = node.element_output
        self.element_compute = node.element_compute
        self.round_style = node.round_style

    @staticmethod
    def match(node, problem_size: tuple):
        # ComputeImpl is the universal implementation: it matches any node
        return True
|
||||
|
||||
|
||||
class ComputeNode(NodeBase):
    """
    Compute Node in the DAG IR.
    """
    possible_impls = [ComputeImpl]

    def __init__(
            self, name: str, fn, element_output,
            element_compute,
            round_style=FloatRoundStyle.ToNearest) -> None:
        super().__init__(name)
        self.op = "compute"
        self.fn = fn
        self.element_compute = element_compute
        self.round_style = round_style
        # NOTE(review): the element_output argument is accepted but not stored
        # here; self.element_output is only assigned lazily in type_propagation
        # unless another pass sets it first — confirm this is intentional.

    def type_propagation(self, *args, **kwargs):
        """
        Propagate element types through this node: the node computes in
        element_compute, and by default also outputs in that type.
        """
        self.element = self.element_compute
        # In general, the compute nodes have element_output = element_compute.
        # In certain cases like producer of D it is overwritten by other passes.
        if not hasattr(self, "element_output"):
            self.element_output = self.element
|
||||
254
python/cutlass_cppgen/backend/evt/ir/dag_ir.py
Normal file
254
python/cutlass_cppgen/backend/evt/ir/dag_ir.py
Normal file
@@ -0,0 +1,254 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
DAG IR used by Python EVT
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.library import ActivationOp
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
|
||||
|
||||
class DAGIR:
    """
    ``DAGIR`` is the central data structure of the EVT intermediate
    representation: a directed graph whose vertices are epilogue visitor nodes.

    Throughout this class a ``node`` is identified by its name (a ``str``);
    ``node_meta`` refers to the underlying node object attached to it.
    """
    def __init__(self, cc, element_compute=DataType.f32) -> None:
        # The IR is stored in a networkX DiGraph
        self._graph = nx.DiGraph()
        self.element_compute = element_compute
        self.reduction_names = []
        self.cc = cc
        # Counter for autogenerated identity nodes (see add_edge)
        self.identity_counter = 0

    #
    # IR manipulator
    #

    def add_node(self, meta: NodeBase):
        """
        Add a node to the DAG IR.
        :raises SyntaxError: if a node with the same name already exists
        """
        if self.has_node(meta.name):
            raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
        self._graph.add_node(meta.name, meta=meta)

    def add_edge(self, src: str, dst: str, weight: int=0):
        """
        Add the weighted edge src -> dst to the DAG IR.

        ``DiGraph`` cannot represent parallel edges, so if src -> dst already
        exists the new edge is routed through an autogenerated identity
        compute node.
        :raises SyntaxError: if either endpoint is undefined
        """
        if not self.has_node(src):
            raise SyntaxError(f"Variable '{src}' is undefined.")
        if not self.has_node(dst):
            raise SyntaxError(f"Variable '{dst}' is undefined.")

        if not self._graph.has_edge(src, dst):
            self._graph.add_edge(src, dst, weight=weight)
            return
        # Parallel-edge workaround: insert an identity node in between
        identity_name = f"autogen_identity_{self.identity_counter}"
        self.identity_counter += 1
        identity_node = ComputeNode(
            name=identity_name, fn=ActivationOp.Identity,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(identity_node)
        self.add_edge(src, identity_name, 0)
        self.add_edge(identity_name, dst, weight)

    def remove_node(self, node: str):
        """
        Remove a node from the DAG IR.
        """
        self._graph.remove_node(node)

    def remove_edge(self, src: str, dst: str):
        """
        Remove the edge src -> dst.
        """
        self._graph.remove_edge(src, dst)

    #
    # Helper functions for getting attrs
    #

    def has_node(self, node: str) -> bool:
        """
        Return True if the node exists in the graph.
        """
        return self._graph.has_node(node)

    def in_degree(self, node: str):
        """
        Number of incoming edges of the node.
        """
        return self._graph.in_degree(node)

    def in_edges(self, node: str):
        """
        Incoming edges of the node as a list of (src, dst) pairs.
        """
        return list(self._graph.in_edges(node))

    def out_degree(self, node: str):
        """
        Number of outgoing edges of the node.
        """
        return self._graph.out_degree(node)

    def out_edges(self, node: str):
        """
        Outgoing edges of the node as a list of (src, dst) pairs.
        """
        return list(self._graph.out_edges(node))

    def get_node_meta(self, node: str):
        """
        Return the meta object attached to the node.
        """
        return self._graph.nodes[node]["meta"]

    def get_edge_weight(self, src, dst):
        """
        Return the weight of the edge src -> dst.
        """
        return self._graph.get_edge_data(src, dst)["weight"]

    #
    # High-level helper functions
    #

    def all_reachable_nodes(self, node: str):
        """
        Nodes reachable from ``node`` in DFS preorder (``node`` itself is the
        first element of the returned list).
        """
        return list(nx.dfs_preorder_nodes(self._graph, source=node))

    def get_users(self, node: str):
        """
        All users (direct successors) of the node.
        """
        return [dst for _, dst in self.out_edges(node)]

    def get_all_inputs(self, node: str):
        """
        All input nodes of ``node``, sorted by edge weight.
        """
        incoming = self.in_edges(node)
        weights = [self.get_edge_weight(u, v) for u, v in incoming]
        return [edge[0] for _, edge in sorted(zip(weights, incoming))]

    def get_all_inputs_meta(self, node: str):
        """
        Meta objects of all input nodes, sorted by edge weight.
        """
        return [self.get_node_meta(src) for src in self.get_all_inputs(node)]

    def replace_all_uses_with(self, node1, node2):
        """
        Redirect every use of node1 to node2 (preserving edge weights) and
        delete node1 from the graph.
        """
        for _, user in self.out_edges(node1):
            weight = self.get_edge_weight(node1, user)
            self.add_edge(node2, user, weight)
            self.remove_edge(node1, user)
        self.remove_node(node1)

    #
    # Node accessor
    #
    def nodes_topological_order(self):
        """
        Nodes in the unique lexicographical topological order.

        Sorting topologically and then lexicographically yields a single
        canonical ordering; plain topological_sort would also be valid, but
        the canonical order gives each epilogue visitor pattern a unique key
        so the compilation cache can be reused.
        :return: list[str]
        """
        return list(nx.lexicographical_topological_sort(self._graph))

    def node_metas_topological_order(self):
        """
        Node metas in topological order.
        :return: list[NodeBase]
        """
        return [self.get_node_meta(node) for node in self.nodes_topological_order()]

    @property
    def nodes(self):
        """
        All node names.
        :return: list[str]
        """
        return list(self._graph.nodes)

    @property
    def nodes_meta(self):
        """
        All node metas.
        :return: list[NodeBase]
        """
        return [attrs["meta"] for _, attrs in self._graph.nodes.data()]

    @property
    def edges(self):
        """
        All edges.
        :return: list[(str, str)]
        """
        return list(self._graph.edges)

    #
    # Path
    #
    def has_path(self, src: str, target: str) -> bool:
        """
        Return True if a path exists from src to target.
        """
        return nx.has_path(self._graph, src, target)
|
||||
324
python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py
Normal file
324
python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Layout algebras
|
||||
"""
|
||||
|
||||
from pycute import Layout, composition, make_layout, flatten, product
|
||||
|
||||
|
||||
def _infer_split(old_shape, new_shape):
    """Recursively infer a flattened shape that both ``old_shape`` and
    ``new_shape`` can be merged back from.

    Works from the trailing dimension backwards: equal dimensions are kept,
    a larger old dimension is split (its quotient carried leftwards), and a
    larger new dimension is assembled from several old ones.
    """
    old_shape = _tuple_to_list(old_shape)
    new_shape = _tuple_to_list(new_shape)

    # Base cases: once one side is exhausted, the other must have volume 1.
    if not old_shape and not new_shape:
        return []
    if not old_shape:
        if product(tuple(new_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return new_shape
    if not new_shape:
        if product(tuple(old_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return old_shape

    # Recurse on the last dimension of each side only.
    old_dim, new_dim = old_shape[-1], new_shape[-1]

    if old_dim == new_dim:
        # Exact match: keep the dimension as-is.
        return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim]

    if old_dim > new_dim and old_dim % new_dim == 0:
        # Old dimension is larger: split off `new_dim` and push the
        # quotient back for further processing.
        return _infer_split(old_shape[:-1] + [old_dim // new_dim], new_shape[:-1]) + [new_dim]

    if old_dim < new_dim and new_dim % old_dim == 0:
        # New dimension is larger: it will be merged from several old
        # dimensions; keep the quotient on the new side.
        return _infer_split(old_shape[:-1], new_shape[:-1] + [new_dim // old_dim]) + [old_dim]

    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
|
||||
|
||||
def _infer_merge(flatten_shape, shape):
    """Group the dimensions of ``flatten_shape`` so that the grouped result
    matches ``shape``: each entry of ``shape`` is either kept verbatim or
    replaced by the list of consecutive flat dims whose product equals it.
    """
    flatten_shape = _tuple_to_list(flatten_shape)
    shape = _tuple_to_list(shape)

    merged_shape = []
    idx_flat = 0
    for dim in shape:
        if dim == flatten_shape[idx_flat]:
            # Exact match: consume a single flat dimension.
            merged_shape.append(dim)
            idx_flat += 1
        elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
            # Consume flat dimensions until their product covers `dim`.
            group = []
            residual = dim
            while residual > 1:
                group.append(flatten_shape[idx_flat])
                residual //= flatten_shape[idx_flat]
                idx_flat += 1
            merged_shape.append(group)
        else:
            raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

    return merged_shape
|
||||
|
||||
def _list_to_tuple(nested_list):
|
||||
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
|
||||
return tuple(_list_to_tuple(item) for item in nested_list)
|
||||
return nested_list
|
||||
|
||||
def _tuple_to_list(nested_tuple):
|
||||
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
|
||||
return list(_tuple_to_list(item) for item in nested_tuple)
|
||||
return nested_tuple
|
||||
|
||||
def _reverse_tuple(nested_tuple: tuple):
|
||||
if isinstance(nested_tuple, tuple):
|
||||
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
|
||||
return nested_tuple
|
||||
|
||||
def _get_first_lhs_nonzero_stride(stride_list, idx):
|
||||
for i in reversed(range(idx)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_first_rhs_nonzero_stride(stride_list, idx):
|
||||
for i in range(idx+1, len(stride_list)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
def reshape(layout, new_shape):
    """
    General reshape of input layout.
    It takes two steps:
    1. split the dimensions of the old layout
    2. merge the splitted dimensions according to the new shape

    ``layout`` is a pycute Layout; ``new_shape`` is a (possibly nested)
    tuple with the same total volume.  Returns a new Layout.
    """
    #
    # Step 1: Split the dimensions of the old layout
    #
    # 1.1 Flat old and new shape
    old_flatten_shape = list(flatten(layout.shape))
    new_flatten_shape = list(flatten(new_shape))

    # 1.2 Infer the flatten splitted shape
    # (the common refinement both shapes can be merged from)
    splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)

    # 1.3 Unflat the splitted shape based on the old shape
    # Entries that had to be split become sub-lists here.
    splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)

    # 1.4 Infer the type of each split
    # If the split type is in row-major (R), the dimension list is reversed because
    # the cute::composition only support column-major split
    split_type = [] # the type of each split (ColumnMajor or RowMajor)
    permuted_splitted_shape = []
    old_flatten_stride = list(flatten(layout.stride))
    for idx, dim in enumerate(splited_shape):
        if not isinstance(dim, list):
            # Unsplit dimension: trivially column-major.
            permuted_splitted_shape.append(dim)
            split_type.append("C")
        else:
            # Decide the split order from the strides of the neighbouring
            # non-broadcast (non-zero-stride) dimensions.
            lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
            rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
            # Special case for single tuple
            # Use column-major by default
            if lhs_stride is None and rhs_stride is None:
                permuted_splitted_shape.append(dim)
                split_type.append("C")
            else:
                if lhs_stride is not None and rhs_stride is not None:
                    # We consider shape[idx]:stride[idx]
                    # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
                    if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                    # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
                    elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
                        # Tie-break on which neighbour dominates.
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
                    elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
                        if lhs_stride >= rhs_stride:
                            permuted_splitted_shape.append(dim)
                            split_type.append("C")
                        else:
                            permuted_splitted_shape.append([d for d in reversed(dim)])
                            split_type.append("R")
                    else:
                        raise NotImplementedError()
                elif lhs_stride is None:
                    # Only a right neighbour exists.
                    # Case 1: dim's stride < dim+1's stride, expand in column major
                    if old_flatten_stride[idx] > rhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")
                else:
                    # Only a left neighbour exists.
                    # Case 1: dim's stride > dim-1's stride
                    if old_flatten_stride[idx] < lhs_stride:
                        permuted_splitted_shape.append([d for d in reversed(dim)])
                        split_type.append("R")
                    else:
                        permuted_splitted_shape.append(dim)
                        split_type.append("C")

    # 1.4 Generate the splitted layout
    # composition refines `layout` by the (column-major-ordered) split shape.
    permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))

    # 1.5 Reverse the permutation in 1.4 before merge
    # Row-major splits were reversed above for composition; undo that here
    # so the dimensions line up with the new shape.
    splitted_shape = []
    splitted_stride = []
    for shape_dim, stride_dim, type in zip(
        permuted_splitted_layout.shape,
        permuted_splitted_layout.stride,
        split_type):
        if type == "C":
            splitted_shape.append(shape_dim)
            splitted_stride.append(stride_dim)
        else:
            splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
            splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
    splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))


    #
    # Step 2: Merge the splitted dimensions according to the new shape
    #
    # 2.1 Merge layout
    merged_layout = composition(splitted_layout, Layout(new_shape))

    # 2.2 Cleaning up
    # NOTE(review): a second composition with the same Layout(new_shape);
    # presumably canonicalizes the merged result — confirm it is required.
    output_layout = composition(merged_layout, Layout(new_shape))
    return output_layout
|
||||
|
||||
|
||||
def permutation(layout, permutation):
    """
    Permute the layout

    Returns a new Layout whose shape/stride pairs are reordered according
    to the given permutation indices.
    """
    reordered = [(layout.shape[idx], layout.stride[idx]) for idx in permutation]
    shapes = tuple(s for s, _ in reordered)
    strides = tuple(d for _, d in reordered)
    return Layout(shapes, strides)
|
||||
|
||||
|
||||
def _broadcast(layout, new_shape):
    """
    Recursive worker for broadcast.

    Leaf case: a single dimension against an int — a matching extent is
    kept, an extent-1 dimension is expanded to the new size with stride 0,
    anything else raises NotImplementedError.

    Tuple case: missing dimensions are padded with (1:0) sub-layouts —
    first on the right-hand side and, if that alignment fails, retried on
    the left-hand side — then each sub-layout is broadcast recursively.
    """
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            # Broadcast dimension: stride 0 so the index is ignored.
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")

    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))

    # Get the broadcasted layout: try right-padding first, then fall back
    # to left-padding if any sub-broadcast is invalid.
    try:
        layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    except NotImplementedError:
        layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
        # BUGFIX: reset the accumulator here; previously partial results
        # appended during the failed right-padded attempt leaked into the
        # left-padded result, producing duplicated leading sub-layouts.
        broadcast_layouts = []
        for idx, sub_layout in enumerate(layout):
            broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
    return make_layout(*broadcast_layouts)
|
||||
|
||||
|
||||
def broadcast(layout, new_shape):
    """
    Broadcast the new layout based on the input shape
    The broadcasted shape equals to the new shape
    The stride of broadcasted dimensions are 0

    Thin public wrapper over the recursive ``_broadcast`` helper.
    Raises NotImplementedError when the shapes are incompatible.
    """
    return _broadcast(layout, new_shape)
|
||||
|
||||
|
||||
def debroadcast(layout, dims):
    """
    Squeeze the 0-stride

    Remove the dimensions listed in ``dims`` from the layout; every such
    dimension must be a broadcast dimension (stride 0), otherwise a
    ValueError is raised.
    """
    # Validate before building anything: only stride-0 dims may be removed.
    for dim in dims:
        if layout.stride[dim] != 0:
            raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
    keep = [idx for idx in range(len(layout.shape)) if idx not in dims]
    new_shape = tuple(layout.shape[idx] for idx in keep)
    new_stride = tuple(layout.stride[idx] for idx in keep)
    return Layout(new_shape, new_stride)
|
||||
|
||||
|
||||
def canonicalization_(shapes, strides):
    """Recursive worker for canonicalization.

    Walks a (possibly nested) shape/stride pair in lockstep and zeroes the
    stride of every extent-1 dimension.  Returns the new (shape, stride)
    pair with the same nesting structure.
    """
    if not isinstance(shapes, tuple):
        # Leaf dimension: a unit extent always gets stride 0.
        return (1, 0) if shapes == 1 else (shapes, strides)
    canonical = [canonicalization_(s, d) for s, d in zip(shapes, strides)]
    c_shapes = tuple(s for s, _ in canonical)
    c_strides = tuple(d for _, d in canonical)
    return c_shapes, c_strides
|
||||
|
||||
def canonicalization(layout):
    """
    Canonicalize the input layout
    1. set the stride of shape "1" to 0

    Returns a new Layout; the input is not modified.
    """
    new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
    return Layout(new_shape, new_stride)
|
||||
336
python/cutlass_cppgen/backend/evt/ir/layout_nodes.py
Normal file
336
python/cutlass_cppgen/backend/evt/ir/layout_nodes.py
Normal file
@@ -0,0 +1,336 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Layout manipulation nodes and implementations
|
||||
|
||||
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
from pycute import product, flatten
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation

    Holds the permutation indices (from the node's kwargs) and their
    inverse, and knows how to propagate the permutation to neighbouring
    nodes in the epilogue visitor graph.
    """
    def __init__(self, node) -> None:
        # The permutation is supplied through the node's "indices" kwarg.
        assert "indices" in node.kwargs.keys()
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        # The inverse permutation simply swaps indices/inverse_indices.
        inverse_impl = deepcopy(self)
        inverse_impl.indices = self.inverse_indices
        inverse_impl.inverse_indices = self.indices
        return inverse_impl

    def update(self, shape):
        """Extend the permutation to cover ``shape``, which may have more
        leading (broadcast) dimensions than the original indices."""
        num_dim = len(shape)
        indices = self.indices
        num_old_dim = len(indices)
        # Add offset: shift the existing indices past the new leading dims.
        for i, idx in enumerate(indices):
            indices[i] = idx + num_dim - num_old_dim
        # Add broadcast dims
        # NOTE(review): prepending with ascending i leaves the new leading
        # indices in descending order; these are broadcast dimensions so the
        # ordering appears benign — confirm.
        for i in range(num_dim - num_old_dim):
            indices = [i,] + indices

        self.indices = indices
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices for inverse permutation
        """
        num_dim = len(indices)
        inverse_indices = [0] * num_dim
        # Invert: if position i maps to indices[i], the inverse maps back.
        for i in range(num_dim):
            inverse_indices[indices[i]] = i
        return inverse_indices

    def shape_propagation(self, input_node_meta):
        # The output shape is the input shape reordered by the permutation.
        input_shape = input_node_meta.tensor.shape
        output_shape = tuple([input_shape[idx] for idx in self.indices])
        return output_shape

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape
        """
        self.update(shape)
        # Broadcast the child in its own (pre-permutation) coordinate order.
        inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to the users of the current nodes
        """
        usr_meta.tensor.permute(self.inverse_indices)
        # Store nodes carry a second tensor that must be kept in sync.
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to inputs of the current nodes
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)
|
||||
|
||||
|
||||
class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape

    Tracks the node's input and output shapes and knows how to propagate a
    reshape (and any broadcast interacting with it) through the epilogue
    visitor graph.
    """
    def __init__(self, node) -> None:
        self.node = node
        # The target shape is supplied through the node's "new_shape" kwarg.
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        # The inverse reshape swaps input and output shapes.
        # NOTE(review): relies on self.input_shape having been set by
        # shape_propagation (or broadcast) before this is called — confirm.
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        # Record the input shape for later inversion / broadcasting.
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: NodeBase):
        """
        Broadcast the inputs based on current shape.
        """
        # Step 1: infer split
        # Find the common refinement of input and output shapes.
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)

        # broadcast shape -> split_output_shape -> flatten_split_shape
        # Pad with leading unit dims so all three align with `shape`.
        if len(shape) - len(split_output_shape) > 0:
            for _ in range(len(shape) - len(split_output_shape)):
                split_output_shape = [1,] + split_output_shape
                flatten_split_shape = [1,] + flatten_split_shape
                split_input_shape = [1,] + split_input_shape
        # Per flat dimension: 1 where no broadcast is needed, otherwise the
        # broadcast extent (only valid when the old extent is 1).
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                assert len(dim) == 1
                broadcast_factor.append(dim[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")

        # flatten_split_shape -> split_input_shape
        # Apply the per-dim factors in input-shape grouping.
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        # Reshape to the split form, broadcast there, then collapse back.
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # Last reshape op to clean up
        broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: NodeBase):
        """
        Propagate the reshape to user nodes
        """
        # Users see this node's pre-reshape (input) shape.
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor"):
            if user_meta.store_tensor is not None:
                user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the reshape to input nodes
        """
        # Inputs see this node's post-reshape (output) shape.
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flatten splitted shape that can be merged to both input_shape and output_shape

        Recursive: processes the trailing dimension of each side per call.
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        # Base cases: an exhausted side requires unit volume on the other.
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively by only process the last dimension at each time
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]

        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        """Group ``flatten_shape`` so the grouped result matches ``shape``.

        Unlike the module-level ``_infer_merge``, this walks both shapes
        from the trailing dimension backwards.
        """
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # need group
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while(residual > 1):
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                # The group was collected right-to-left; restore order.
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

        # Dimensions were appended in reverse; restore original order.
        return merged_shape[::-1]
|
||||
|
||||
|
||||
class LayoutNode(NodeBase):
    """
    Layout manipulation nodes

    A LayoutNode wraps a layout-changing function ("permute" or "reshape")
    and delegates all graph propagation to the matching impl class.
    """
    # Maps the user-facing function name to its implementation class.
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }
    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        # Dispatch on the function's __name__ ("permute" / "reshape").
        self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)

    def get_inverse_node(self):
        # A copy of this node performing the inverse layout transformation.
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        # Skip if the tensor was already computed.
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"

        output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])

        self._tensor = Tensor(
            element=self.element_output,
            shape=output_shape, layout_tag=LayoutType.RowMajor
        )

        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store nodes has element_output = element_input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        shape = self.tensor.shape

        for child in input_node_metas:
            self.underlying_impl.broadcast(shape, child)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the permutation to user nodes
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the permutation to input nodes
        """
        self.underlying_impl.apply_to_input(input_meta)
|
||||
294
python/cutlass_cppgen/backend/evt/ir/load_nodes.py
Normal file
294
python/cutlass_cppgen/backend/evt/ir/load_nodes.py
Normal file
@@ -0,0 +1,294 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Load nodes and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
|
||||
|
||||
class LoadImplBase(ImplBase):
    """
    Base class for load node implementations

    Captures the element types and tensor stride shared by all load impls.
    """
    # Node names claimed by dedicated impls (accumulator and operand C);
    # generic load impls must not match these.
    reserved_names = ["accum", "C"]
    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.tensor.stride
|
||||
|
||||
class AccumulatorImpl(LoadImplBase):
    """
    Accumulator node implementation
    """

    @staticmethod
    def match(node, problem_size: tuple):
        # Matches only the node literally named "accum" whose tensor shape
        # equals the problem size (no broadcasting involved).
        return node.name == "accum" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class LoadSrcImpl(LoadImplBase):
    """
    Load C implementation

    Handles the epilogue source operand C (the node named "C" with the
    full problem-size shape).
    """
    @property
    def name_camel(self) -> str:
        # Operand C has a fixed name in the emitted code.
        return "TensorC"

    @property
    def argument_type_c(self):
        # NOTE(review): other load impls expose `argument_type`; this one is
        # `argument_type_c` — presumably consumed specially for operand C
        # by the caller; confirm against the argument-assembly code.
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        # ctypes argument struct matching the C++ epilogue argument layout.
        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", tuple_type)
            ]
            def __init__(self, ptr) -> None:
                self.ptr_C = ptr
                self.stride_C = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        return node.name == "C" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class AuxLoadImpl(LoadImplBase):
    """
    Load arbitrary tensor

    Matches any non-reserved node whose last two strides indicate a dense
    row- or column-major MN tile (contiguous in exactly one of M/N).
    """
    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        # Capture into locals so the nested class closes over values, not self.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element
        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dAux", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                # The pointer is looked up by this node's name.
                ptr = kwargs[name]
                self.ptr_aux = ptr
                # Value substituted when the pointer is null.
                self.null_default = to_ctype_value(0, element_type)
                self.dAux = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        # Contiguous along exactly one of the two trailing modes.
        if (strideMN[0] == 1 and strideMN[1] != 0 or
            strideMN[0] != 0 and strideMN[1] == 1 ):
            return True
        else:
            return False
|
||||
|
||||
|
||||
class RowBroadcastImpl(LoadImplBase):
    """
    Broadcast a row vector

    Matches non-reserved nodes whose trailing (M, N) strides are (0, 1):
    constant along M, contiguous along N.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        # Row/column vector strides fit in a plain int in the C++ argument.
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        # Capture into locals so the nested class closes over values, not self.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element
        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_row", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dRow", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                # NOTE(review): ColumnBroadcastImpl converts with int(ptr)
                # here; this impl assigns the raw value — confirm both
                # accept the same pointer-like inputs.
                self.ptr_row = ptr
                self.null_default = to_ctype_value(0, element_type)
                self.dRow = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 1):
            return True
        else:
            return False
|
||||
|
||||
|
||||
class ColumnBroadcastImpl(LoadImplBase):
    """
    Broadcast a column vector

    Matches non-reserved nodes whose trailing (M, N) strides are (1, 0):
    contiguous along M, constant along N.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        # Row/column vector strides fit in a plain int in the C++ argument.
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        # Capture into locals so the nested class closes over values, not self.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element
        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_col", ctypes.c_void_p),
                ("null_default", dtype2ctype[element_type]),
                ("dCol", tuple_type)
            ]
            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                # Coerce to a plain integer address for c_void_p.
                self.ptr_col = int(ptr)
                self.null_default = to_ctype_value(0, element_type)
                self.dCol = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (1, 0):
            return True
        else:
            return False
|
||||
|
||||
|
||||
class ScalarBroadcastImpl(LoadImplBase):
    """
    Broadcast a scalar

    Matches non-reserved nodes whose trailing (M, N) strides are (0, 0).
    The scalar can be baked in at trace time (constant tensor) or supplied
    at run time as either a Python float or a device pointer.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        stride_mnl = self.get_stride_mnl()
        # Capture into locals so the nested class closes over values, not self.
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_type = self.element

        if self.tensor.is_constant:
            # Constant scalar: the value is fixed now; no runtime kwarg read.
            value = self.tensor.value
            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]
                def __init__(self, kwargs) -> None:
                    self.scalars = to_ctype_value(value, element_type)
                    # Null pointer signals "use the immediate value".
                    self.scalar_ptrs = 0
                    self.dScalar = tuple_type(stride_mnl)

        else:
            class _Argument(ctypes.Structure):
                _fields_ = [
                    ("scalars", dtype2ctype[element_type]),
                    ("scalar_ptrs", ctypes.c_void_p),
                    ("dScalar", tuple_type)
                ]
                def __init__(self, kwargs) -> None:
                    # The kwarg may be an immediate float or a device pointer.
                    scalar_or_ptr = kwargs[name]
                    if isinstance(scalar_or_ptr, float):
                        self.scalars = to_ctype_value(scalar_or_ptr, element_type)
                        self.scalar_ptrs = 0
                    else:
                        # NOTE(review): `scalars` is left at its ctypes
                        # zero-initialized default in this branch — confirm
                        # the C++ side ignores it when scalar_ptrs is set.
                        self.scalar_ptrs = int(scalar_or_ptr)

                    self.dScalar = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        if strideMN == (0, 0):
            return True
        else:
            return False
|
||||
|
||||
|
||||
class LoadNode(NodeBase):
    """
    Load Node: reads a tensor (accumulator, source, aux, or broadcast)
    from its producer into the epilogue.
    """
    # Class-level counter used to generate unique default names ("load0", ...).
    cnt = 0
    # Candidate implementations tried in order by NodeBase.get_underlying_impl;
    # the first whose `match` accepts this node wins.
    possible_impls = [
        AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
        RowBroadcastImpl, ColumnBroadcastImpl,
        ScalarBroadcastImpl
    ]
    def __init__(self, name: str) -> None:
        # Auto-name anonymous load nodes with a monotonically increasing id.
        if name is None:
            name = f"load{LoadNode.cnt}"
            LoadNode.cnt += 1
        super().__init__(name)
        self.op = "load"

    def type_propagation(self, *args, **kwargs):
        """
        Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.

        :raises RuntimeError: if the node's tensor has not been set yet.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")

        # Both the in-memory and the produced element types equal the
        # tensor's element type for a load.
        self.element = self.tensor.element
        self.element_output = self.tensor.element
306
python/cutlass_cppgen/backend/evt/ir/node.py
Normal file
306
python/cutlass_cppgen/backend/evt/ir/node.py
Normal file
@@ -0,0 +1,306 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base & visitor classes of DAGIR Nodes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from re import sub
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class TupleEmitter:
    """
    Emit a Python (nested) stride tuple as a C++ cute type declaration.

    Integers 0 and 1 become compile-time constants ``cute::Int<0/1>``; any
    other integer is emitted as the runtime stride type (``stride_dtype``);
    tuples map recursively to ``cute::Stride<...>``.
    """
    def __init__(self, stride_dtype):
        # Name of the C++ integer type used for non-constant stride entries.
        self.stride_dtype = stride_dtype

    def emit(self, py_tuple):
        """
        Recursively translate ``py_tuple`` into a C++ type string.

        :param py_tuple: an int or an arbitrarily nested tuple of ints
        :return: the C++ type declaration as a string
        :raises ValueError: if an element is neither int nor tuple
        """
        if isinstance(py_tuple, int):
            if py_tuple in (0, 1):
                # Static strides are encoded as compile-time integers.
                return f"cute::Int<{py_tuple}>"
            return f"{self.stride_dtype}"
        elif isinstance(py_tuple, tuple):
            # Fix: the original built the string with a trailing ", " and
            # sliced it off with [:-2], which mangled the output for an
            # empty tuple ("cute::Strid>"). Joining is correct for all sizes.
            inner = ", ".join(self.emit(item) for item in py_tuple)
            return f"cute::Stride<{inner}>"
        else:
            raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
|
||||
|
||||
|
||||
class ImplBase:
    """
    Base class for Node Implementation.

    An implementation wraps a DAG node and knows how to describe it to the
    C++ emitter (stride type strings) and to the runtime (ctypes argument
    structures).
    """
    def __init__(self, node) -> None:
        self.node = node
        self.name = node.name
        self.tensor = node.tensor
        # Cached C++ type declaration; lazily populated elsewhere.
        self._type_decl = None
        # Default stride element type for emitted cute::Stride declarations.
        self.tuple_emitter = TupleEmitter("int64_t")

    @property
    def stride_dtype(self):
        # Delegates to the tuple emitter so emitted C++ stays in sync.
        return self.tuple_emitter.stride_dtype

    @stride_dtype.setter
    def stride_dtype(self, stride_dtype):
        self.tuple_emitter.stride_dtype = stride_dtype

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match function used in get_underlying_impl.

        Subclasses return True when they can implement `node`; the base
        class has no default behavior.
        """
        raise NotImplementedError(f"The `match` function is not defined.")

    @property
    def argument_type(self):
        """
        Default class for Argument Type: an empty ctypes structure for
        implementations that take no runtime arguments.
        """
        class _Argument(ctypes.Structure):
            _fields_ = []

            def __init__(self, *args, **kwargs) -> None:
                pass

        return _Argument

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name (underscores/dashes removed, words capitalized).
        """
        return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")

    @property
    def stride_mnl(self):
        """
        Typename StrideMNL: the C++ declaration of the (M, N, L...) stride,
        with batch dims reversed and appended after the trailing M, N strides.
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return self.tuple_emitter.emit(stride)

    def get_non_constant_stride(self, py_tuple):
        """
        Recursively keep only stride entries that are not compile-time
        constants (0 or 1).

        NOTE(review): nested results are filtered with a truthiness check
        (`if item_out:`), so an all-constant sub-tuple collapses to () and
        is dropped entirely — presumably intentional; confirm.
        """
        if isinstance(py_tuple, int):
            if py_tuple not in [0, 1]:
                return py_tuple
            else:
                return None
        non_constant_stride = []
        for item in py_tuple:
            item_out = self.get_non_constant_stride(item)
            if item_out:
                non_constant_stride.append(item_out)
        return tuple(non_constant_stride)

    def get_stride_mnl(self):
        """
        Get the non-zero stride mnl. This is used in argument construction.
        Same (M, N, L...) reordering as `stride_mnl`, but returns the Python
        tuple rather than its C++ declaration.
        """
        stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
        return stride

    def get_smem_size(self, *args, **kwargs):
        """
        Get the shared memory size and alignment of current node.

        :return: (size_in_bytes, alignment); the base implementation uses none.
        """
        return (0, 1)
|
||||
|
||||
|
||||
class NoOpImpl(ImplBase):
    """
    The NoOpImpl does nothing but forward its input to users.
    """
    def __init__(self, node) -> None:
        super().__init__(node)

    @staticmethod
    def match(node, problem_size: tuple):
        """
        A store node that is not a kernel output has no observable effect
        and can be lowered to a no-op.

        Fix: the original fell off the end for non-store nodes and returned
        ``None``; return ``False`` explicitly so the function always yields
        a bool like every other ``match`` implementation.
        """
        if node.op == "store":
            # Store that is not output is a No OP
            return not node.is_output
        return False
|
||||
|
||||
|
||||
class NodeBase:
    """
    Base class of DAG Node.

    Holds the node's name, its (lazily-assigned) output tensor, and the
    implementation selected by `get_underlying_impl`.
    """
    def __init__(self, name: str) -> None:
        self.name = name
        # Selected by get_underlying_impl(); None until then.
        self.underlying_impl = None

        # Output tensor; set directly, via the `tensor` setter, or inferred
        # by shape_propagation().
        self._tensor = None

        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Setting the tensor: `kwargs` is a dict forwarded to the Tensor
        constructor (element/shape/layout_tag/stride/tensor).
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer shape from input nodes
        General Broadcasting Rules from NumPy
        When operating on two arrays, we compare their shapes element-wise.
        It starts with the trailing (i.e. rightmost) dimension and works its
        way left. Two dimensions are compatible when
        1. they are equal
        2. one of them is 1

        :raises RuntimeError: if two input shapes cannot be broadcast together.
        """
        # A tensor set explicitly beforehand takes precedence.
        if self._tensor is not None:
            return

        shape = None
        for src in input_node_metas:
            src_shape = src.tensor.shape
            if shape is None:
                shape = src_shape
            else:
                # Left-pad the shorter shape with 1s so both have equal rank.
                len_difference = len(shape) - len(src_shape)
                if len_difference > 0:
                    for _ in range(len_difference):
                        src_shape = [1, ] + list(src_shape)
                elif len_difference < 0:
                    for _ in range(-len_difference):
                        shape = [1, ] + list(shape)
                broadcasted_shape = []
                # Infer broadcast shape
                for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
                    if shape_dim == 1:
                        broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
                    elif src_dim == 1:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    elif shape_dim == src_dim:
                        broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
                    else:
                        # Incompatible dimensions: report all input shapes.
                        error_msg = "Dimension mismatch between "
                        for src_ in input_node_metas:
                            error_msg += f"{src_.name}{src_.tensor.shape}, "
                        error_msg = error_msg[:-2] + "."
                        raise RuntimeError(error_msg)
                shape = tuple(broadcasted_shape)

        self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        The `element_output` is the type of return array of the node. The `element`
        has specific meaning for different node types.
        * Load Node: data type of tensor in gmem
        * Compute Node: element compute
        * Store Node: data type of tensor in gmem
        This function must be overloaded in the derived classes
        """
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After the broadcast propagation, it will become
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        and each tensor will have a proper stride accessing the underlying tensor

        :raises RuntimeError: if this node's tensor has not been set.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        for child in input_node_metas:
            # Broadcasting mutates the child's layout in place.
            child.tensor.broadcast(self.tensor.shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Get the underlying implementation of the current node.

        Tries `possible_impls` (defined by subclasses) in order; the first
        impl whose `match` accepts this node is instantiated.

        :raises RuntimeError: if the node's tensor is unset.
        :raises NotImplementedError: if no impl matches.
        """
        if self.tensor is None:
            raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")

        for impl in self.possible_impls:
            if impl.match(self, problem_size):
                self.underlying_impl = impl(self)
                break

        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
|
||||
|
||||
#
|
||||
# Visitor Nodes & Impls
|
||||
#
|
||||
|
||||
class TopoVisitorImpl(ImplBase):
    """
    Implementation backing a topological-visitor (DAG) node.
    """
    def __init__(self, node) -> None:
        # The impl wraps the subgraph's output node, but keeps the DAG
        # node's own name and the output node's element type.
        output = node.output_node
        super().__init__(output)
        self.name = node.name
        self.element_output = output.element_output
|
||||
|
||||
class TopoVisitorNode(NodeBase):
    """
    DAG node that visits a subgraph in topological order.
    """
    def __init__(self, name: str, subgraph, output_node) -> None:
        super().__init__(name)
        self.op = "dag"
        # The subgraph being visited and its designated output node.
        self.subgraph = subgraph
        self.output_node = output_node
        # DAG nodes have exactly one implementation; bind it eagerly.
        self.underlying_impl = TopoVisitorImpl(self)
|
||||
277
python/cutlass_cppgen/backend/evt/ir/store_nodes.py
Normal file
277
python/cutlass_cppgen/backend/evt/ir/store_nodes.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Store node and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
|
||||
|
||||
|
||||
class StoreImplBase(ImplBase):
    """
    Base class for store node implementations.
    """
    # Tensor names with a dedicated implementation (the kernel output D).
    reserved_names = ["D"]

    def __init__(self, node) -> None:
        super().__init__(node)
        # Layout and data types are captured from the node being lowered.
        self.stride = node.store_tensor.stride
        self.element = node.element
        self.element_output = node.element_output
|
||||
|
||||
|
||||
class StoreDImpl(StoreImplBase):
    """
    Implementation for storing the kernel output tensor D.
    """

    @property
    def argument_type_d(self):
        """
        ctypes structure type holding (ptr_D, stride_D) for the epilogue.
        """
        stride_mnl = self.get_stride_mnl()
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", tuple_type)
            ]

            def __init__(self, ptr: int) -> None:
                self.ptr_D = ptr
                self.stride_D = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Matched only by the reserved node "D" whose shape equals the
        problem size.
        """
        return node.name == "D" and node.store_tensor.shape == problem_size
|
||||
|
||||
|
||||
class AuxStoreImpl(StoreImplBase):
    """
    Store an auxiliary output tensor to global memory.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.round_style = FloatRoundStyle.ToNearest

    @property
    def argument_type(self):
        """
        ctypes structure type holding (ptr_aux, dAux) for the epilogue.
        """
        stride_mnl = self.get_stride_mnl()
        tensor_name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("dAux", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                # Pointer supplied by the caller under this node's name.
                self.ptr_aux = kwargs[tensor_name]
                self.dAux = tuple_type(stride_mnl)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Matches non-reserved output tensors that are contiguous (stride 1)
        in the M or the N mode without the other mode being broadcast away.
        """
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        m_stride, n_stride = node.store_tensor.stride[-2:]
        return (m_stride == 1 and n_stride != 0) or (m_stride != 0 and n_stride == 1)
|
||||
|
||||
|
||||
class ReductionImplBase(StoreImplBase):
    """
    Base class for reduction store implementations (row/column/scalar).

    Besides the fields captured by StoreImplBase, a reduction carries the
    compute data type, the register- and gmem-level reduction functions,
    and the float rounding style, all taken from the node being lowered.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.store_tensor.element
        self.element_compute = node.element_compute
        self.reg_reduce_fn = self.node.reg_reduce_fn
        self.gmem_reduce_fn = self.node.gmem_reduce_fn
        self.round_style = node.round_style
        self.stride_dtype = "int"

    def get_reduce_identity(self):
        """
        Return the reduction identity of the current reduce_fn:
        the type's "min" for Maximum, "max" for Minimum, 1 for Multiplies,
        and 0 otherwise.

        :raises Exception: if no identity is tabulated for element_compute.
        """
        # NOTE(review): 2**15 is not the true fp16 finite max (65504) —
        # presumably a "large enough" sentinel; confirm before reuse.
        maxes = {
            DataType.f32: (2 ** 31) - 1,
            DataType.f16: (2 ** 15),
            DataType.s32: (2 ** 31) - 1,
            DataType.s8: (2 ** 7) - 1
        }
        mins = {
            DataType.f32: -maxes[DataType.f32],
            DataType.f16: -maxes[DataType.f16],
            DataType.s32: -maxes[DataType.s32],
            DataType.s8: -maxes[DataType.s8]
        }
        if self.reg_reduce_fn == FunctionalOp.Maximum:
            if self.element_compute not in mins:
                raise Exception(f"No min entry for data type {self.element_compute}")
            return to_ctype_value(mins[self.element_compute], self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Multiplies:
            return to_ctype_value(1., self.element_compute)
        elif self.reg_reduce_fn == FunctionalOp.Minimum:
            if self.element_compute not in maxes:
                raise Exception(f"No max entry for data type {self.element_compute}")
            return to_ctype_value(maxes[self.element_compute], self.element_compute)
        else:
            return to_ctype_value(0., self.element_compute)

    @property
    def argument_type(self):
        """
        ctypes structure type holding (ptr, reduce_identity, dMNL).
        """
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_compute = self.element_compute
        # Fix: the original called get_reduce_identity() twice, discarding
        # the first result; compute it exactly once here.
        reduce_identity = self.get_reduce_identity()

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr", ctypes.c_void_p),
                ("reduce_identity", dtype2ctype[element_compute]),
                ("dMNL", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                # Pointer supplied by the caller under this node's name.
                self.ptr = kwargs[name]
                self.reduce_identity = reduce_identity
                self.dMNL = tuple_type(stride_mnl)

        return _Argument
|
||||
|
||||
|
||||
class ColumnReductionImpl(ReductionImplBase):
    """
    Reduction whose output is a column vector (MN stride (1, 0)).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (1, 0)
|
||||
|
||||
|
||||
class RowReductionImpl(ReductionImplBase):
    """
    Reduction whose output is a row vector (MN stride (0, 1)).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (0, 1)
|
||||
|
||||
|
||||
class ScalarReductionImpl(ReductionImplBase):
    """
    Reduction whose output is a single scalar (MN stride (0, 0)).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False
        return node.store_tensor.stride[-2:] == (0, 0)
|
||||
|
||||
|
||||
class StoreNode(NodeBase):
    """
    Store node: writes a result (output tensor, aux tensor, or reduction)
    or forwards its input when it is not a kernel output.
    """
    # Candidate implementations tried in order by NodeBase.get_underlying_impl.
    possible_impls = [
        AuxStoreImpl, RowReductionImpl,
        ColumnReductionImpl, ScalarReductionImpl,
        NoOpImpl, StoreDImpl
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.op = "store"
        # Whether this store is a kernel output (set by the frontend).
        self.is_output = False
        self._store_tensor = None

    @property
    def store_tensor(self) -> Tensor:
        """
        Return the tensor being stored (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._store_tensor

    @store_tensor.setter
    def store_tensor(self, kwargs):
        """
        Setting the tensor: `kwargs` is a dict forwarded to the Tensor
        constructor.
        """
        self._store_tensor = Tensor(**kwargs)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store node has element_output = element_input.

        :raises RuntimeError: if this is an output node without a store
            tensor, or if the node does not have exactly one input.
        """
        if self.is_output:
            if self.store_tensor is None:
                raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
            self.element = self.store_tensor.element
        # Fix: the original used `assert`, which is stripped under `-O`;
        # raise explicitly, matching the RuntimeError style used by the
        # other propagation checks in this IR.
        if len(input_node_metas) != 1:
            raise RuntimeError("Store node can only have one input node")
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate broadcasting to inputs; for output nodes, also broadcast
        the store tensor to the node's output shape.
        """
        super().broadcast_propagation(input_node_metas)
        if self.is_output:
            self._store_tensor.broadcast(self.tensor.shape)
|
||||
137
python/cutlass_cppgen/backend/evt/ir/tensor.py
Normal file
137
python/cutlass_cppgen/backend/evt/ir/tensor.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
High-level class for tensor
|
||||
"""
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
||||
Layout,
|
||||
broadcast,
|
||||
canonicalization,
|
||||
permutation,
|
||||
reshape,
|
||||
_reverse_tuple
|
||||
)
|
||||
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
||||
|
||||
|
||||
class Tensor:
    """
    The tensor abstracts the data type, shape, and layout of a frontend
    tensor for the EVT IR.

    Construct either from a framework tensor (`tensor=...`) or from an
    explicit (element, shape, layout_tag-or-stride) triple — never both.
    Shapes/strides are stored internally in reversed (cute) order; the
    public `shape`/`stride` properties present RowMajor order.
    """
    def __init__(self, tensor=None, element=None, shape=None, stride=None, layout_tag=None, is_constant=False) -> None:
        # Mutually exclusive argument validation: `tensor` stands alone.
        if element is not None and tensor is not None:
            raise Exception(f"Must not specify both element and tensor")
        elif shape is not None and tensor is not None:
            raise Exception(f"Must not specify both shape and tensor")
        elif layout_tag is not None and tensor is not None:
            raise Exception(f"Must not specify both layout_tag and tensor")
        elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None):
            raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
        elif stride is not None and tensor is not None:
            raise Exception(f"Must not specify both stride and tensor")
        elif stride is not None and layout_tag is not None:
            raise Exception(f"Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                self.element = library_type(element)
            else:
                # Infer dtype/layout/shape from the framework tensor.
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)
            if stride is not None:
                # Explicit stride: store shape/stride reversed (cute order).
                self.layout = Layout(shape[::-1], stride[::-1])
            else:
                if layout_tag == LayoutType.RowMajor:
                    self.layout = Layout(shape[::-1])
                elif layout_tag == LayoutType.ColumnMajor:
                    self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
                # NOTE(review): if layout_tag is neither RowMajor nor
                # ColumnMajor, self.layout is never assigned and the next
                # line raises AttributeError — confirm only these two tags
                # can reach here.
                self.layout = canonicalization(self.layout)

            self.is_constant = is_constant
            # Save the tensor value if it is constant
            if is_constant and tensor is not None:
                self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank of the tensor
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to shape (given in RowMajor order).
        """
        assert isinstance(shape, tuple)
        self.layout = broadcast(self.layout, _reverse_tuple(shape))

    def reshape(self, shape):
        """
        Reshape self.layout to shape (given in RowMajor order).
        """
        assert isinstance(shape, tuple)
        reverse_shape = _reverse_tuple(shape)
        self.layout = reshape(self.layout, reverse_shape)

    def permute(self, indices):
        """
        Permute self.layout according to indices (RowMajor-order indices
        are translated to the internal reversed order).
        """
        length = len(indices)
        indices = [length - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, indices[::-1])
|
||||
42
python/cutlass_cppgen/backend/evt/passes/__init__.py
Normal file
42
python/cutlass_cppgen/backend/evt/passes/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
|
||||
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
|
||||
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
|
||||
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
|
||||
# ===== new file: python/cutlass_cppgen/backend/evt/passes/graph_drawer.py (143 lines) =====
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
|
||||
from cutlass_library import DataTypeTag
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
||||
|
||||
|
||||
# Graphviz fillcolor assigned to each DAG node kind, keyed by ``node.op``
# (consumed by EVTGraphDrawer._get_node_style).
# NOTE(review): the "load" entry carries embedded double quotes
# ('"AliceBlue"') unlike the other values — presumably deliberate graphviz
# quoting; confirm it is intended before normalizing.
_COLOR_MAP = {
    "load": '"AliceBlue"',
    "compute": "LemonChiffon1",
    "accumulator": "LightGrey",
    "store": "PowderBlue",
    "layout": "lightseagreen",
    "dag": "darkorange"
}
|
||||
|
||||
|
||||
class EVTGraphDrawer:
    """
    Visualize an EVT DAGIR with graphviz (via pydot).

    On construction the main graph — and, recursively, the subgraph of every
    "dag" node — is converted to a ``pydot.Dot`` and cached in
    ``self._dot_graphs`` keyed by graph name.
    """
    def __init__(
        self,
        graph: DAGIR,
        name: str
    ):
        self._name = name
        self._dot_graphs = {}

        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        """Return the pydot node attributes (shape/fill/font) for ``node``."""
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            raise NotImplementedError("unknown node op")
        # Disabled nodes are rendered greyed-out.
        if node.disabled:
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        """Build the record-shaped label string listing the node's attributes."""
        label = "{" + f"name={node.name}|op={node.op}"
        if node.op == "layout":
            label += f"|fn={node.fn.__name__}"
            for key in node.kwargs:
                label += f"|{key}={node.kwargs[key]}"
        if node.underlying_impl is not None:
            label += f"|impl={type(node.underlying_impl).__name__}"
            if node.op == "load":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
            elif node.op == "compute":
                label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "store":
                label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "dag":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
        if node.tensor is not None:
            shape = node.tensor.shape
            stride = node.tensor.stride
            label += f"|shape={shape}|stride={stride}"

        if hasattr(node, "store_tensor"):
            if node.store_tensor is not None:
                store_shape = node.store_tensor.shape
                store_stride = node.store_tensor.stride
                # Bug fix: label key was "stride_stride", a typo for "store_stride".
                label += f"|store_shape={store_shape}|store_stride={store_stride}"

        label += "}"
        return label

    def _to_dot(
        self,
        graph: DAGIR,
        name: str
    ):
        """Convert ``graph`` to a ``pydot.Dot``; recurses into "dag" subgraphs."""
        import pydot
        # Bug fix: keyword was misspelled "randir"; graphviz silently ignores
        # unknown attributes, so the intended top-to-bottom layout never applied.
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                # Cache the subgraph drawing under the dag node's own name.
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph

        # Add edges, labeled with their weight
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> list[tuple[str, pydot.Dot]]:
        """Return (name, dot graph) pairs for the main graph and all subgraphs."""
        # Annotation fixed: this returns a list of pairs, not a single pydot.Dot.
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> pydot.Dot:
        """Return the cached dot graph with the given name."""
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> pydot.Dot:
        """Return the dot graph of the top-level DAG IR."""
        return self._dot_graphs[self._name]
|
||||
# ===== new file: python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py (120 lines) =====
@@ -0,0 +1,120 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.c_types import visitor_factory
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetArgumentType(EVTPassBase):
    """
    Build the epilogue visitor argument type for every node of the DAG IR.
    """
    dependencies = [
        PassShapeTypePropagation,  # layouts of all nodes must already be set
        PassDAG2Tree,              # the type of each node must already be set
        PassGetImpl                # the DAG subgraphs must already be set
    ]

    def requires(self) -> None:
        # Sm90+ epilogues must return a tensor named "D".
        needs_d = cc_map[self.cc] in [90, 100]
        if needs_d and not self.dag_ir.has_node("D"):
            raise SyntaxError(
                "Sm90+ EVT requires the epilogue to have a returned tensor D, "
                "but the variable 'D' is not found in the return values.")

    def call(self):
        """Record the argument type of every enabled node, then finalize per-cc."""
        self.argument_types = {}
        is_sm90_plus = cc_map[self.cc] in [90, 100]
        for node in self.dag_ir.nodes_topological_order():
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.argument_types[node] = meta.underlying_impl.argument_type
            # On Sm90+, "D" is handled separately by sm90_set_argument_type.
            if node == "D" and is_sm90_plus:
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_argument_type(node)
            else:
                self.get_evt_argument_type(node)

        self.cc_specific_method(self.set_argument_type)()

    def get_evt_argument_type(self, node):
        """Fold the (weight-sorted) children's types into this node's visitor type."""
        children = self.dag_ir.get_all_inputs(node)
        if children:
            child_types = [self.argument_types[child] for child in children]
            self.argument_types[node] = visitor_factory(
                child_types + [self.argument_types[node]], children + [node])

    def get_dag_argument_type(self, node):
        """Build the visitor type of a fused DAG node from its subgraph."""
        subgraph = self.dag_ir.get_node_meta(node).subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Register argument types for subgraph nodes not visited at the top level.
        for sub_node in subgraph_nodes:
            sub_meta = subgraph.get_node_meta(sub_node)
            if not sub_meta.disabled:
                self.argument_types[sub_node] = sub_meta.underlying_impl.argument_type
        # The last topological node is the subgraph output; the rest form the visitor.
        interior = subgraph_nodes[:-1]
        interior_types = [self.argument_types[child] for child in interior]
        if interior_types:
            self.argument_types[node] = visitor_factory(interior_types, interior)

    def set_argument_type(self):
        # Default is a no-op; cc_specific_method dispatches to the smNN_ variants.
        pass

    def sm90_set_argument_type(self):
        # The epilogue thread visitor is the producer of "D".
        producer_of_d = self.dag_ir.get_all_inputs("D")[0]
        self.dag_ir.epilogue_thread_type = self.argument_types[producer_of_d]
        # Tensor D argument type
        self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d

        # Tensor C argument type (falls back to D's type when C is absent)
        if self.dag_ir.has_node("C"):
            self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
        else:
            self.dag_ir.arg_c_type = self.dag_ir.arg_d_type

    def sm100_set_argument_type(self):
        # Sm100 follows the Sm90 convention.
        self.sm90_set_argument_type()

    def sm80_set_argument_type(self):
        # On Sm80 the whole epilogue is the type of the last topological node.
        last_node = self.dag_ir.nodes_topological_order()[-1]
        self.dag_ir.epilogue_thread_type = self.argument_types[last_node]
|
||||
# ===== new file: python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py (169 lines) =====
@@ -0,0 +1,169 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
||||
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassDAG2Tree(EVTPassBase):
    """
    Convert the DAG IR to Tree by fusing subgraphs.

    Non-tree regions (nodes with out-degree > 1) are packed into a single
    TopoVisitorNode whose subgraph is executed by the topological visitor;
    the remainder of the graph stays a tree for the tree visitor.
    """
    dependencies = [
        PassShapeTypePropagation,
        PassGetImpl
    ]

    def call(self):
        """Fuse every multi-parent region of the DAG into one TopoVisitorNode."""
        # Step 1: find the nodes that have multiple parents
        multi_parent_nodes = []

        for node in self.dag_ir.nodes_topological_order():
            if self.dag_ir.out_degree(node) > 1:
                multi_parent_nodes.append(node)
        # Step 2: find the lowest common ancestor (LCA) of all its parents
        # NOTE(review): the loops below rebind `node`, shadowing this loop
        # variable; harmless (it is reassigned each iteration) but confusing.
        for node in multi_parent_nodes:
            # A multi-parent node could be already fused by the previous node
            if not self.dag_ir.has_node(node):
                continue
            # A node uncovered by the previous fusions can have out degree change
            # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
            # Case 2: it has more than one edges to the previously fused subgraph, degree drops
            if self.dag_ir.out_degree(node) <= 1:
                continue

            # Otherwise, the node still has multiple users: collect, per user,
            # the set of nodes reachable from that user.
            reachable_nodes = []
            # Complexity: O(Dout*N)
            for parent in self.dag_ir.get_users(node):
                reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
            # get the common reachable objects
            common_items = set.intersection(*reachable_nodes)
            # Nodes reachable from some-but-not-all users must be fused.
            node_to_fuse = set.union(*reachable_nodes).difference(common_items)

            lca = None
            # If common ancestor exists, find the lowest one
            # (the earliest common item in topological order).
            if len(common_items) > 0:
                topo_order = self.dag_ir.nodes_topological_order()
                topo_idx = -1
                for item in common_items:
                    if lca is None:
                        lca = item
                        topo_idx = topo_order.index(item)
                    else:
                        if topo_idx > topo_order.index(item):
                            lca = item
                            topo_idx = topo_order.index(item)
            else:
                # there is no common ancestor for all the parents, we pack all the reachable
                # nodes into a single DAG node as a fallback. The lca should be the input node of
                # one of the output nodes with out_degree = 0
                potential_output_nodes = []
                for node in node_to_fuse:
                    if self.dag_ir.out_degree(node) == 0:
                        potential_output_nodes.append(node)
                if len(potential_output_nodes) == 0:
                    raise RuntimeError(f"No output node with out degree = 0 found.")

                output_node = None
                if (self.dag_ir.cc >= 90):
                    # For SM90+, the lca should be the input node of D
                    if (not self.dag_ir.has_node("D")):
                        raise RuntimeError(f"D is not a node in the DAG IR.")
                    output_node = "D"
                else:
                    output_node = potential_output_nodes[0]

                if (output_node is None):
                    raise RuntimeError(f"No output node found.")
                lca = self.dag_ir.get_all_inputs(output_node)[0]
                node_to_fuse.remove(output_node)

            # The lca is the output node of the DAG node
            # Get the nodes to be fused
            node_to_fuse.add(lca)
            # Get all the input nodes
            all_input_nodes = []
            all_output_nodes = []
            for node in node_to_fuse:
                all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
                all_output_nodes.append(set(self.dag_ir.get_users(node)))
            all_input_nodes = set.union(*all_input_nodes)
            all_output_nodes = set.union(*all_output_nodes)

            # The subgraph also carries the (disabled) boundary nodes.
            new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)

            # Create the subgraph
            subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
            subgraph = DAGIR(self.dag_ir.cc)
            for node in subgraph_.nodes:
                # Deep-copy metadata so subgraph edits don't alias the main IR.
                meta = deepcopy(self.dag_ir.get_node_meta(node))
                if node not in node_to_fuse:
                    # Boundary nodes are kept for context but not executed here.
                    meta.disabled = True
                subgraph.add_node(meta)
            for edge in subgraph_.edges:
                subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))


            # Create the fused node; the lca becomes its output node.
            dag_node = TopoVisitorNode(
                name=f"dag_{lca}", subgraph=subgraph,
                output_node=self.dag_ir.get_node_meta(lca))
            self.dag_ir.add_node(dag_node)

            # Add input edges
            # NOTE(review): iteration order over this set determines edge
            # weights; presumably downstream sorts by weight — confirm.
            for idx, node in enumerate(all_input_nodes):
                self.dag_ir.add_edge(node, dag_node.name, weight=idx)

            # Replace all uses with DAG node (only 1 output node)
            self.dag_ir.replace_all_uses_with(lca, dag_node.name)

            # Remove all fused nodes
            node_to_fuse.remove(lca)
            for node in node_to_fuse:
                self.dag_ir.remove_node(node)

    def ensures(self) -> None:
        """Verify the pass postcondition: every node has out-degree <= 1."""
        # Ensure that after the pass, the resulting DAG becomes a tree
        for node in self.dag_ir.nodes:
            out_degree = self.dag_ir.out_degree(node)
            if out_degree > 1:
                raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
|
||||
# ===== new file hunk (64 lines): pass_fix_element_d.py (name inferred from contents) =====
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Fix the element_output of producer of D.
|
||||
|
||||
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have element_output = type(D).
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassFixElementD(EVTPassBase):
    """
    Force the compute node that produces tensor D to output type(D).

    In the Sm90 epilogue visitor, the node writing D to gmem has no internal
    element converter, so the compute node producing D must already have
    element_output = type(D).
    """
    dependencies = [
        PassLayoutManipulateElimination
    ]
    def get_producer(self, node, element_D):
        # Walk upward from `node`: through store nodes until the producing
        # compute node is reached, then fix its output element type.
        node_meta = self.dag_ir.get_node_meta(node)
        op = node_meta.op
        if op == "compute":
            node_meta.element_output = element_D
        elif op == "store":
            upstream = self.dag_ir.get_all_inputs(node)[0]
            self.get_producer(upstream, element_D)

    def call(self):
        # Nothing to do when the epilogue does not return a tensor "D".
        if not self.dag_ir.has_node("D"):
            return
        element_D = self.dag_ir.get_node_meta("D").store_tensor.element
        self.get_producer("D", element_D)
|
||||
# ===== new file: python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py (90 lines) =====
@@ -0,0 +1,90 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Infer the underlying implement of each node.
|
||||
|
||||
While the frontend only distinguish between Load/Store/Compute Node,
|
||||
each of these nodes can have different underlying implementation based
|
||||
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
||||
This pass infers the underlying impl of each node
|
||||
"""
|
||||
|
||||
import cutlass_cppgen.backend.evt.backend as evt_backend
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetImpl(EVTPassBase):
    """
    Infer the underlying implementation of each node.

    While the frontend only distinguishes between Load/Store/Compute nodes,
    each of these can have a different underlying implementation based on its
    layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast,
    etc. This pass infers the underlying impl of each node.
    """
    dependencies = [
        PassShapeTypePropagation,  # The shape and type info are required for inference
        PassFixElementD
    ]

    def __init__(self, dag_ir: DAGIR) -> None:
        super().__init__(dag_ir)
        # Invoked in ensures(): some nodes lower to NoOp and must be removed.
        self.no_op_elimination = PassNoOpElimination(dag_ir)

    def requires(self) -> None:
        # Verify "accum" is in the arg list
        if not self.dag_ir.has_node("accum"):
            raise SyntaxError("Cannot find 'accum' in the argument list.")

    def call(self):
        """Infer each node's generic underlying impl from the accumulator shape."""
        # The loop structure of the epilogue is determined by the
        # accumulator shape
        accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
        problem_size = accumulator.tensor.shape

        for node_meta in self.dag_ir.node_metas_topological_order():
            node_meta.get_underlying_impl(problem_size)

    def ensures(self) -> None:
        """Eliminate NoOps, then lower each impl to its cc-specific class."""
        # Some nodes will be lowered to NoOp, eliminate them
        self.no_op_elimination()
        # Lower to cc-specific impl, e.g. FooImpl -> Sm90FooImpl.
        # The backend module lookup is loop-invariant, so hoist it.
        cc = cc_map[self.cc]
        node_impl_ccs = getattr(evt_backend, f"sm{cc}_nodes")
        for node_meta in self.dag_ir.nodes_meta:
            node_meta.underlying_impl = getattr(
                node_impl_ccs,
                f"Sm{cc}" + node_meta.underlying_impl.__class__.__name__
            )(node_meta)
|
||||
# ===== new file hunk (217 lines): pass_layout_elimination.py (name inferred from contents) =====
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Eliminate layout manipulation nodes
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassLayoutManipulateElimination(EVTPassBase):
|
||||
"""
|
||||
Eliminate layout manipulation nodes
|
||||
"""
|
||||
dependencies = [PassShapeTypePropagation]
|
||||
|
||||
    def __init__(self, dag_ir: DAGIR) -> None:
        """Initialize the pass over ``dag_ir``."""
        super().__init__(dag_ir)
        # Monotonic counter used to build unique names for copied layout nodes
        # (see add_copy_before / add_copy_after).
        self.copy_cnt = 0
|
||||
|
||||
    def call(self):
        """Eliminate every layout-manipulation node by propagating its effect away."""
        self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Run while loop until all layout nodes are eliminated; propagation may
        # append copied layout nodes to the worklist.
        while(len(self.layout_nodes_worklist) > 0):
            node = self.layout_nodes_worklist.pop(0)
            # for node in layout_nodes:
            # Step 1: get the propagation direction (away from "accum")
            direction = self.get_propagation_direction(node)
            self.visited = []
            # Step 2: dispatch to propagate_to_users / propagate_to_inputs
            # (defined further down in this class).
            getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
            # Eliminate the current node by splicing its single input over it.
            input_node = self.dag_ir.get_all_inputs(node)[0]
            self.dag_ir.replace_all_uses_with(node, input_node)
            # layout_nodes = self.get_all_layout_nodes()
|
||||
|
||||
def get_all_layout_nodes(self):
|
||||
layout_nodes = []
|
||||
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
|
||||
if isinstance(node_meta, LayoutNode):
|
||||
layout_nodes.append(node_meta.name)
|
||||
return layout_nodes
|
||||
|
||||
    def get_propagation_direction(self, node: str):
        """
        The logic is propagating all layout nodes away from the accumulator node.

        Returns "inputs" when propagating toward the users would reach "accum"
        (so the layout op must be pushed toward the inputs instead), and
        "users" in the mirrored case. Raises if "accum" is influenced in both
        or neither direction.
        """
        # Nodes influenced if we propagated toward the users
        # (get_influenced_users accumulates into self.visited).
        self.visited = []
        self.get_influenced_users(node)
        nodes_influenced_dir_users = self.visited
        # Nodes influenced if we propagated toward the inputs
        self.visited = []
        self.get_influenced_inputs(node)
        nodes_influenced_dir_inputs = self.visited

        if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
            return "inputs"
        elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
            return "users"
        else:
            raise RuntimeError("Unsolved propagation direction")
|
||||
|
||||
# Get all influenced nodes if we propagate along the user direction
|
||||
    def get_influenced_users(self, node: str):
        """
        Record in self.visited every node that would be affected if a
        layout transform were propagated from `node` toward its users.

        Mutually recursive with get_influenced_inputs: the *other* inputs
        of each user are influenced in the input direction.
        """
        if node in self.visited:
            return
        # Mark before recursing so diamonds/shared subgraphs terminate.
        self.visited.append(node)

        users = self.dag_ir.get_users(node)
        for user in users:
            self.get_influenced_users(user)
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            # `node` itself is an input of every user; exclude it.
            user_inputs.remove(node)
            for input in user_inputs:
                self.get_influenced_inputs(input)
|
||||
|
||||
# Get all influenced nodes if we propagate along the input direction
|
||||
    def get_influenced_inputs(self, node: str):
        """
        Record in self.visited every node that would be affected if a
        layout transform were propagated from `node` toward its inputs.

        Mutually recursive with get_influenced_users: the *other* users
        of each input are influenced in the user direction.
        """
        if node in self.visited:
            return
        # Mark before recursing so diamonds/shared subgraphs terminate.
        self.visited.append(node)

        inputs = self.dag_ir.get_all_inputs(node)
        for input in inputs:
            self.get_influenced_inputs(input)
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            # `node` itself is a user of every input; exclude it.
            input_users.remove(node)
            for user in input_users:
                self.get_influenced_users(user)
|
||||
|
||||
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
|
||||
copied_node_meta = deepcopy(layout_node_meta)
|
||||
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
||||
self.copy_cnt += 1
|
||||
copied_node_meta.name = copied_node
|
||||
self.dag_ir.add_node(copied_node_meta)
|
||||
# Add edges
|
||||
target_inputs = self.dag_ir.get_all_inputs(target)
|
||||
for src in target_inputs:
|
||||
self.dag_ir.remove_edge(src, target)
|
||||
self.dag_ir.add_edge(src, copied_node)
|
||||
self.dag_ir.add_edge(copied_node, target)
|
||||
self.layout_nodes_worklist.append(copied_node)
|
||||
|
||||
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
|
||||
copied_node_meta = deepcopy(layout_node_meta)
|
||||
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
||||
self.copy_cnt += 1
|
||||
copied_node_meta.name = copied_node
|
||||
self.dag_ir.add_node(copied_node_meta)
|
||||
# Add edges
|
||||
users = self.dag_ir.get_users(target)
|
||||
for user in users:
|
||||
self.dag_ir.remove_edge(target, user)
|
||||
self.dag_ir.add_edge(copied_node, user)
|
||||
self.dag_ir.add_edge(target, copied_node)
|
||||
self.layout_nodes_worklist.append(copied_node)
|
||||
|
||||
# Propagate the layout `node` along the user direction
|
||||
    def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to users, recursively applying the transform
        to every reachable node and the inverse transform to side branches.
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        # The node that started the propagation is skipped (name matches).
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # Layout node is not transparent with layout node:
                # materialize a copy in front of it and stop this branch.
                self.add_copy_before(layout_node_meta, node)
                return
            else:
                # Non-layout nodes absorb the transform directly.
                layout_node_meta.apply_to_user(node_meta)

        users = self.dag_ir.get_users(node)
        # Collected before recursing; NOTE(review): presumably because the
        # recursion (via add_copy_before) can rewire edges — confirm.
        user_inputs = []
        for user in users:
            user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
        for user in users:
            self.propagate_to_users(layout_node_meta, user)
        if len(user_inputs) > 0:
            user_inputs = set.union(*user_inputs)
            user_inputs.remove(node)
            # Side branches feeding the same users receive the inverse
            # transform propagated toward their inputs, keeping the overall
            # computation unchanged.
            for input in user_inputs:
                self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
|
||||
|
||||
# Propagate the layout `node` along the input direction
|
||||
    def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
        """
        Propagate layout node to inputs, recursively applying the transform
        to every reachable node and the inverse transform to side branches.
        """
        if node in self.visited:
            # Avoid applying twice
            return
        self.visited.append(node)

        node_meta = self.dag_ir.get_node_meta(node)
        # The node that started the propagation is skipped (name matches).
        if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # Layout node is not transparent with layout node:
                # materialize a copy behind it and stop this branch.
                self.add_copy_after(layout_node_meta, node)
                return
            else:
                # Non-layout nodes absorb the transform directly.
                layout_node_meta.apply_to_input(node_meta)
        inputs = self.dag_ir.get_all_inputs(node)
        # Collected before recursing; NOTE(review): presumably because the
        # recursion (via add_copy_after) can rewire edges — confirm.
        input_users = []
        for input in inputs:
            input_users.append(set(self.dag_ir.get_users(input)))
        for input in inputs:
            self.propagate_to_inputs(layout_node_meta, input)
        if len(input_users) > 0:
            input_users = set.union(*input_users)
            input_users.remove(node)
            # Side branches consuming the same inputs receive the inverse
            # transform propagated toward their users, keeping the overall
            # computation unchanged.
            for user in input_users:
                self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
|
||||
164
python/cutlass_cppgen/backend/evt/passes/pass_manager.py
Normal file
164
python/cutlass_cppgen/backend/evt/passes/pass_manager.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Pass manager for DAG IR.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EVTPassBase:
    """
    Base class for EVT (Epilogue Visitor Tree) compiler passes.

    A pass receives the DAG IR at construction time and transforms it in
    ``call()``. ``requires()`` / ``ensures()`` are optional pre- and
    post-condition hooks that run around ``call()`` when the pass is
    invoked through ``__call__``.
    """
    # Pass classes listed here are scheduled before this pass by the
    # EVTPassManager.
    dependencies = []

    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        # Target compute capability, taken from the IR.
        self.cc = self.dag_ir.cc

    def requires(self) -> None:
        """
        This function will be called before the pass is run.
        """
        pass

    def call(self) -> None:
        """
        The pass that is run through the self.dag_ir
        """
        # Fixed: the message previously said `__call__`, but the method a
        # subclass must override is `call`.
        raise NotImplementedError(
            f"call is not overwritten in Pass {self.__class__.__name__}")

    def ensures(self) -> None:
        """
        This function will be called after the pass is run.
        """
        pass

    def __call__(self) -> Any:
        self.requires()
        self.call()
        self.ensures()

    def cc_specific_method(self, func):
        """
        This enables defining functions that behave differently under
        different compute capabilities. The simplest example of using this
        function is the following

        .. highlight:: python
        .. code-block:: python

            class ExamplePass(EVTPassBase):

                def call(self):
                    # This automatically selects the smXX_func based on current cc
                    self.cc_specific_method(self.func)()

                # Interface func, can be empty
                def func(self):
                    pass

                # Sm90 specific func
                def sm90_func(self):
                    # sm90 specific method
                    return

                # Sm80 specific func
                def sm80_func(self):
                    # sm80 specific method
                    return
        """
        # NOTE(review): cc_map presumably collapses minor compute
        # capabilities onto the base architecture used in the method name
        # (e.g. 86 -> 80) — confirm in passes.util.
        func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
        if hasattr(self, func_name):
            return getattr(self, func_name)
        else:
            raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
|
||||
|
||||
|
||||
class EVTPassManager(nx.DiGraph):
    """
    Topological-based Pass Manager.

    Each registered pass declares a list of dependency passes. The manager
    models the passes as a DAG (one graph node per pass, one edge per
    dependency) and runs them in topological order.
    """
    def __init__(self, dag_ir: DAGIR, pass_list):
        super().__init__()
        self.dag_ir = dag_ir
        for pass_cls in pass_list:
            self.add_pass(pass_cls)

        self.sorted_passes = self.schedule()

    def get_callable(self, pass_name):
        """
        Return the instantiated pass registered under `pass_name`.
        """
        return self.nodes[pass_name]["callable"]

    def add_pass(self, pass_cls):
        """
        Add a pass to the pass manager
        :param pass_cls: the class of pass
        :type pass_cls: derived class of EVTPassBase
        """
        # The graph node is keyed by class name; the instantiated pass is
        # stored as node data.
        self.add_node(pass_cls.__name__, callable=pass_cls(self.dag_ir))

    def schedule(self):
        """
        Schedule the added passes under topological order
        """
        # One edge per declared dependency: dependency -> dependent pass.
        for pass_name in self.nodes:
            instance = self.get_callable(pass_name)
            for dependency_cls in instance.dependencies:
                self.add_edge(dependency_cls.__name__, type(instance).__name__)

        return list(nx.topological_sort(self))

    def __call__(self) -> Any:
        """
        Launch the registered passes
        """
        for pass_name in self.sorted_passes:
            self.get_callable(pass_name)()
|
||||
@@ -0,0 +1,53 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
No op elimination node
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassNoOpElimination(EVTPassBase):
    """
    Removes every node whose underlying implementation is NoOpImpl from
    the DAG IR, rerouting its users to its input.
    """
    dependencies = []

    def call(self) -> Any:
        for node_name in self.dag_ir.nodes_topological_order():
            meta = self.dag_ir.get_node_meta(node_name)
            if isinstance(meta.underlying_impl, NoOpImpl):
                # A no-op forwards its (assumed single) input unchanged,
                # so its users can consume that input directly.
                sole_input = self.dag_ir.get_all_inputs(node_name)[0]
                self.dag_ir.replace_all_uses_with(node_name, sole_input)
|
||||
@@ -0,0 +1,97 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Preprocess the reduction nodes.
|
||||
|
||||
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
|
||||
This pass fuses these into a single store node, and then replaces all uses of the
|
||||
current node with the new store node.
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassPreprocessRed(EVTPassBase):
    """
    Preprocess red(uction) nodes.

    The parser emits a reduction as Compute(op=(reg_reduce_fn,
    gmem_reduce_fn)) -> Store(). This pass folds each such compute node
    into its succeeding store node and removes the compute node from the
    graph.
    """

    def call(self):
        # Step 1: find the compute nodes with op=red
        red_compute_nodes = []
        for node_meta in self.dag_ir.nodes_meta:
            if isinstance(node_meta, ComputeNode):
                if type(node_meta.fn) == tuple:
                    # To keep the frontend simple, the reduction nodes
                    # are parsed into compute nodes by default.
                    # The simple heuristic to distinguish between compute
                    # and reduction node is that compute node is a single function,
                    # while the reduction node is a tuple of functions for
                    # in-register reduction and atomic global memory reduction.
                    red_compute_nodes.append(node_meta.name)

        # Step 2: for each compute, merge it with the succeeding store
        for node in red_compute_nodes:
            # Verify the expected Compute -> Store shape.
            users = self.dag_ir.get_users(node)
            inputs = self.dag_ir.get_all_inputs(node)
            # Exactly one user and one input.
            assert len(users) == 1
            assert len(inputs) == 1
            user = users[0]
            input = inputs[0]

            user_meta = self.dag_ir.get_node_meta(user)
            # Must be a store node
            assert isinstance(user_meta, StoreNode)
            # With output degree == 0 (a graph sink)
            assert self.dag_ir.out_degree(user) == 0
            # Register the reduce op and its compute type on the store.
            node_meta = self.dag_ir.get_node_meta(node)
            user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
            user_meta.element_compute = node_meta.element_compute
            user_meta.round_style = node_meta.round_style

            # Replace all uses: every other consumer of `input` now reads
            # from the store node, and the store reads `input` directly.
            self.dag_ir.remove_edge(input, node)
            input_users = self.dag_ir.get_users(input)
            for iu in input_users:
                # Preserve the edge weight when rerouting.
                weight = self.dag_ir.get_edge_weight(input, iu)
                self.dag_ir.add_edge(user, iu, weight)
                self.dag_ir.remove_edge(input, iu)
            self.dag_ir.add_edge(input, user)
            self.dag_ir.remove_node(node)

            # Register the reduction name
            self.dag_ir.reduction_names.append(user)
|
||||
@@ -0,0 +1,59 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Shape and type propagation pass
|
||||
"""
|
||||
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
|
||||
|
||||
class PassShapeTypePropagation(EVTPassBase):
    """
    Propagates element types and tensor shapes through the DAG IR: a
    forward (topological) sweep for type/shape inference followed by a
    backward sweep for broadcast propagation.
    """
    dependencies = [PassPreprocessRed]

    def call(self):
        topo_order = self.dag_ir.nodes_topological_order()

        # Forward sweep: each node derives its type and shape from its inputs.
        for name in topo_order:
            meta: NodeBase = self.dag_ir.get_node_meta(name)
            inputs_meta = self.dag_ir.get_all_inputs_meta(name)
            meta.type_propagation(inputs_meta)
            meta.shape_propagation(inputs_meta)

        # Backward sweep: broadcast requirements flow from outputs to inputs.
        for name in reversed(topo_order):
            meta: NodeBase = self.dag_ir.get_node_meta(name)
            meta.broadcast_propagation(self.dag_ir.get_all_inputs_meta(name))
|
||||
319
python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py
Normal file
319
python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py
Normal file
@@ -0,0 +1,319 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Compute the shared memory size in bytes
|
||||
"""
|
||||
|
||||
from math import gcd
|
||||
|
||||
import cutlass_library
|
||||
from pycute import flatten, shape_div, product
|
||||
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
from cutlass_cppgen.backend.library import DataType, DataTypeSize
|
||||
|
||||
|
||||
class GetSmemSize:
|
||||
"""
|
||||
Get the size in byte of shared memory used by the kernel
|
||||
"""
|
||||
    def __init__(self, dag_ir: DAGIR) -> None:
        self.dag_ir = dag_ir
        # Target compute capability from the IR; selects the
        # sm{cc}_epilogue_smem_size implementation in __call__.
        self.cc = self.dag_ir.cc
|
||||
|
||||
#
|
||||
# Sm90 epilogue specific
|
||||
#
|
||||
|
||||
    def sm90_epilogue_tile(self, tile_description):
        """
        Derive the sm90 epilogue tile shape and pipeline staging, and
        record them on self for the later smem-size computation.

        NOTE(review): the tile heuristics mirror the C++ epilogue builder;
        confirm against the CUTLASS collective epilogue builder if changed.
        """
        # Get the epilogue tile size
        schedule = tile_description.epilogue_schedule
        if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
            element_d = self.dag_ir.get_node_meta("D").element
            # Wider perf target only for 8-bit D with N divisible by 64.
            nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
            epi_tile_m = min(64, tile_description.threadblock_shape[0])
            # gcd guarantees the epilogue tile N divides the CTA tile N.
            epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
            epi_tile_m = min(128, tile_description.threadblock_shape[0])
            epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
            epilogue_tile_mn = (epi_tile_m, epi_tile_n)
        else:
            raise NotImplementedError(f"Unsupported schedule: {schedule}")

        # Get the pipeline stages
        stages_d = 2
        # Number of epilogue tiles per CTA tile (MN only).
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
        else:
            element_c = None

        element_d = self.dag_ir.get_node_meta("D").element
        # C and D smem buffers can be shared only when the element types match.
        if element_c == element_d:
            reuse_smem_c = True
        else:
            reuse_smem_c = False
        stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles

        # Record the epilogue tile configuration for the smem-size pass.
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        # Source tensor C participates only when present in the graph.
        self.is_source_supported = element_c is not None
|
||||
|
||||
    def sm90_or_sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the total epilogue shared-memory size in bytes, modeling
        the C++ SharedStorage struct: fusion (visitor) storage, C/D tensor
        staging buffers, and the load pipeline barriers.

        Requires the smXX_epilogue_tile() state (epilogue_tile_mn, stages,
        element types) to be set beforehand.
        """
        # Get the Fusion Storage: per-node smem (size, alignment) pairs,
        # accumulated bottom-up into nested visitor structs.
        nodes = self.dag_ir.nodes_topological_order()
        self.smem_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.smem_types[node] = meta.underlying_impl.get_smem_size(
                    self.cta_tile_mnk, self.epilogue_tile_mn,
                    self.stages_c, self.stages_d, self.epi_tiles)
            # The root store "D" is not folded into a visitor struct.
            if node == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_smem_type(node)
            else:
                self.get_evt_smem_type(node)

        # The visitor tree rooted at D's input carries the fusion storage size.
        thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
        # Get the Tensor Storage: (size, alignment) members of the struct.
        tensors = []
        if self.is_source_supported:
            smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
            tensors.append((smem_C, 128))
        else:
            tensors.append((0, 1))
        if self.reuse_smem_c:
            # D aliases the C buffer, so it contributes no extra bytes.
            tensors.append((0, 128))
        else:
            smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
            tensors.append((smem_D, 128))
        tensors.append((thread_smem_size, 128))

        tensor_smem_size = self.get_struct_size(tensors)
        # Get pipeline storage size:
        # sizeof(uint64_t * stages_c * 2), alignment of uint64_t;
        # 2 is for FullBarrier and EmptyBarrier
        pipeline_smem_size = (8 * self.stages_c * 2, 8)

        # get SharedStorage size (size only; alignment is dropped)
        smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
        return smem_size[0]
|
||||
|
||||
    def sm90_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm90 collective epilogue
        """
        # Derive the epilogue tile/stage configuration first, then apply
        # the shared-storage layout common to sm90 and sm100.
        self.sm90_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
||||
|
||||
#
|
||||
# Sm100 epilogue specific
|
||||
#
|
||||
|
||||
    def sm100_epilogue_tile(self, tile_description):
        """
        Derive the sm100 (Blackwell) epilogue tile shape and pipeline
        staging, and record them on self for the smem-size computation.

        NOTE(review): the heuristics mirror the C++ sm100 epilogue builder;
        confirm against CUTLASS before changing any constant.
        """
        cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
        mma_tile = cta_tile

        # For 2-SM MMA the per-CTA tile is half of the MMA tile in M.
        if tile_description.is_2sm:
            cta_tile = (cta_tile[0] // 2, cta_tile[1])

        # Warp layout over tensor memory.
        if tile_description.is_2sm and mma_tile[0] == 128:
            tmem_warps = (2, 2)
        else:
            tmem_warps = (4, 1)

        if self.dag_ir.has_node("C"):
            element_c = self.dag_ir.get_node_meta("C").element
            element_c_size = DataTypeSize[element_c]
        else:
            element_c = None
            element_c_size = 0

        element_d = self.dag_ir.get_node_meta("D").element

        # Source C is disabled when absent or explicitly void.
        DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void

        CtaM = cta_tile[0]
        CtaN = cta_tile[1]
        WarpM = tmem_warps[0]
        WarpN = tmem_warps[1]
        MaxBits = max(element_c_size, DataTypeSize[element_d])
        DpFull = 32
        M = min(CtaM, DpFull * WarpM)

        if DisableSource:
            # Epilogues w/o residual load are less sensitive to smem allocation
            # Target a fixed amount of compute per epilogue iteration
            if MaxBits == 4:
                # Make epilogue tile larger to reduce the epilogue iterations.
                # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
                ComputeElts = 8192
                Nperf = ComputeElts // M
            else:
                ComputeElts = 4096
                Nperf = ComputeElts // M
        else:
            # Epilogues w/ residual load are more sensitive to smem allocation
            # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
            if MaxBits == 32:
                Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
            elif MaxBits == 16:
                Nperf = 32 if CtaN <= 128 else 64
            else:
                Nperf = 64

        def is_m_major(layout):
            # M-major means the first (M) mode is stride-1.
            return flatten(layout.stride[0]) == 1

        # Minimum epilogue tile N imposed by C's layout/element width.
        if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
            N_min_C = 8 * WarpN
        elif element_c_size == 6:
            N_min_C = 128 * WarpN
        else:
            N_min_C = (128 // element_c_size) * WarpN

        # Minimum epilogue tile N imposed by D's layout/element width.
        if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
            N_min_D = 8 * WarpN
        elif DataTypeSize[element_d] == 6:
            N_min_D = 128 * WarpN
        else:
            N_min_D = (128 // DataTypeSize[element_d]) * WarpN

        N = min(CtaN, max(Nperf, N_min_C, N_min_D))

        tile_m = M
        # Round N down to a multiple of the warp count in N.
        tile_n_size = N // WarpN * WarpN

        epilogue_tile_mn = (tile_m, tile_n_size)
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))

        stages_d = min(epi_tiles, 2)
        # Reuse the C buffer for D only for element widths above 8 bits.
        reuse_smem_c = (element_c_size > 8)

        if reuse_smem_c:
            stages_c = max(min(epi_tiles, 4), stages_d + 1)
        else:
            stages_c = min(epi_tiles, 4)

        # Record the epilogue tile configuration for the smem-size pass.
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = not DisableSource
|
||||
|
||||
    def sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm100 collective epilogue
        """
        # Derive the epilogue tile/stage configuration first, then apply
        # the shared-storage layout common to sm90 and sm100.
        self.sm100_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
||||
|
||||
    def __call__(self, tile_description):
        # Dispatch to sm90_epilogue_smem_size / sm100_epilogue_smem_size
        # based on the target compute capability.
        return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
|
||||
@staticmethod
|
||||
def get_visitor_size(members: list, ebo: bool):
|
||||
"""
|
||||
Get the size of struct in bytes
|
||||
"""
|
||||
offset = 0
|
||||
max_alignment = 1
|
||||
if len(members) > 0:
|
||||
# Get alignment
|
||||
for _, alignment in members:
|
||||
max_alignment = max(max_alignment, alignment)
|
||||
|
||||
for type_size, _ in members:
|
||||
if type_size != 0:
|
||||
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
||||
if type_size == 0 and not ebo:
|
||||
offset += 1
|
||||
else:
|
||||
offset += type_size
|
||||
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
||||
return (offset, max_alignment)
|
||||
else:
|
||||
# Struct size is at least 1
|
||||
return (1, 1)
|
||||
|
||||
def get_struct_size(self, members: list):
    """
    Return ``(size, alignment)`` in bytes of a struct with the given
    members, never applying empty-base optimization.
    """
    # Delegate to the generic helper with EBO disabled.
    return self.get_visitor_size(members, False)
|
||||
|
||||
def get_evt_smem_type(self, node):
    """
    Fold the smem types of ``node`` and its inputs into one struct type.

    Gathers the smem type of every input of ``node`` plus the node's own
    type; when there is more than one member, the combined struct size
    replaces the node's entry in ``self.smem_types``.
    """
    # Collect the smem type of each input edge, then the node's own type.
    member_types = [self.smem_types[src] for src in self.dag_ir.get_all_inputs(node)]
    member_types.append(self.smem_types[node])
    if len(member_types) <= 1:
        # A node with no inputs keeps its own smem type unchanged.
        return
    # Empty-base optimization is applied once there are more than 4 members.
    apply_ebo = len(member_types) > 4
    self.smem_types[node] = self.get_visitor_size(member_types, apply_ebo)
|
||||
|
||||
def get_dag_smem_type(self, node):
    """
    Compute the smem type of a DAG node from its subgraph.

    Walks the node's subgraph in topological order, recording the smem size
    of every enabled sub-node, then folds all but the last sub-node's types
    into a single struct type stored under ``node``.
    """
    subgraph = self.dag_ir.get_node_meta(node).subgraph
    topo_nodes = subgraph.nodes_topological_order()
    # Record the smem type of each enabled node in the subgraph.
    for sub_node in topo_nodes:
        sub_meta = subgraph.get_node_meta(sub_node)
        if sub_meta.disabled:
            # Disabled nodes contribute no shared memory.
            continue
        self.smem_types[sub_node] = sub_meta.underlying_impl.get_smem_size(
            self.cta_tile_mnk, self.epilogue_tile_mn,
            self.stages_c, self.stages_d, self.epi_tiles)
    # All sub-nodes except the last (the subgraph output) become members.
    member_types = [self.smem_types[sub_node] for sub_node in topo_nodes[:-1]]
    if member_types:
        # Empty-base optimization is applied once there are more than 4 members.
        apply_ebo = len(member_types) > 4
        self.smem_types[node] = self.get_visitor_size(member_types, apply_ebo)
|
||||
46
python/cutlass_cppgen/backend/evt/passes/util.py
Normal file
46
python/cutlass_cppgen/backend/evt/passes/util.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for passes
|
||||
"""
|
||||
|
||||
# Map from the CC of the kernel to the EVT implementation that the CC targets.
# Ampere-family parts (SM80/86/89) share the SM80 EVT, Hopper uses SM90, and
# Blackwell-family parts (SM100/101/103) share the SM100 EVT.
cc_map = {
    cc: target
    for target, cc_group in ((80, (80, 86, 89)), (90, (90,)), (100, (100, 101, 103)))
    for cc in cc_group
}
|
||||
Reference in New Issue
Block a user