mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Enhance Collective Check at MSCCLang (#511)
This commit is contained in:
153
python/mscclpp/language/collective_checker.py
Normal file
153
python/mscclpp/language/collective_checker.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import List
|
||||
from mscclpp.language.buffer import *
|
||||
from mscclpp.language.collectives import Collective
|
||||
from mscclpp.language.types import DataFormat, ChannelType, ChunkRef
|
||||
from mscclpp.language.ir import *
|
||||
|
||||
|
||||
class CollectiveChecker:
|
||||
def __init__(self):
|
||||
self.collective = None
|
||||
self.buffers = None
|
||||
|
||||
def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
|
||||
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
|
||||
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
|
||||
sb = self.buffers[src][src_buffer]
|
||||
db = self.buffers[dst][dst_buffer]
|
||||
for i in range(size):
|
||||
db[dst_index + i] = sb[src_index + i]
|
||||
|
||||
def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
|
||||
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
|
||||
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
|
||||
sb = self.buffers[src][src_buffer]
|
||||
db = self.buffers[dst][dst_buffer]
|
||||
for i in range(size):
|
||||
reduce_chunk = db[dst_index + i]
|
||||
sent_chunk = sb[src_index + i]
|
||||
db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)
|
||||
|
||||
def check_buffer_exists(self, rank, name):
|
||||
if name not in self.buffers[rank]:
|
||||
self.buffers[rank][name] = BufferSlice(Buffer.scratch, name)
|
||||
|
||||
def _get_buffer_index(self, src_rank, remote_rank, buffer, index):
|
||||
if index == -1 and buffer == None:
|
||||
return src_rank.buffer, src_rank.index
|
||||
elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
|
||||
return buffer, self.buffers[remote_rank][buffer].instance_size()
|
||||
return buffer, index
|
||||
|
||||
def _put(self, src, dst, buffer=None, index=-1):
|
||||
self.check_buffer_exists(dst, buffer)
|
||||
buffer, index = self._get_buffer_index(src, dst, buffer, index)
|
||||
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
|
||||
|
||||
def _put_packet(
|
||||
self,
|
||||
src,
|
||||
dst,
|
||||
buffer=None,
|
||||
index=-1,
|
||||
sendtb=-1,
|
||||
src_format=DataFormat.raw,
|
||||
chan_type=ChannelType.memory,
|
||||
temp_buffer=None,
|
||||
temp_buffer_index=-1,
|
||||
):
|
||||
chunk_ref = src
|
||||
if chan_type == ChannelType.port and src_format == DataFormat.raw:
|
||||
self._copy(src.rank, temp_buffer, temp_buffer_index)
|
||||
self._put(chunk_ref, dst, buffer, index)
|
||||
|
||||
def _get(self, dst, src, buffer=None, index=-1):
|
||||
self.check_buffer_exists(src, buffer)
|
||||
buffer, index = self._get_buffer_index(dst, src, buffer, index)
|
||||
self.apply_send(src, buffer, index, dst.rank, dst.buffer, dst.index, dst.size)
|
||||
|
||||
def _copy(self, src, dst, buffer=None, index=-1):
|
||||
self.check_buffer_exists(dst, buffer)
|
||||
buffer, index = self._get_buffer_index(src, dst, buffer, index)
|
||||
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
|
||||
|
||||
def _reduce(self, src, other_chunkref):
|
||||
dst = src.rank
|
||||
src_rank = other_chunkref.rank
|
||||
self.apply_reduce(src_rank, other_chunkref.buffer, other_chunkref.index, dst, src.buffer, src.index, src.size)
|
||||
|
||||
def _group_load_reduce(self, src, other_chunkrefs: list):
|
||||
for other_chunkref in other_chunkrefs:
|
||||
src_chunkref = other_chunkref
|
||||
self.apply_reduce(
|
||||
src_chunkref.rank,
|
||||
src_chunkref.buffer,
|
||||
src_chunkref.index,
|
||||
src.rank,
|
||||
src.buffer,
|
||||
src.index,
|
||||
src.size,
|
||||
)
|
||||
|
||||
def _group_store(self, src, dsts: list, index=-1, buffer=None):
|
||||
for dst in dsts:
|
||||
self.check_buffer_exists(dst, buffer)
|
||||
|
||||
for dst in dsts:
|
||||
buffer, index = self._get_buffer_index(src, dst, buffer, index)
|
||||
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
|
||||
|
||||
def _execute(self, operations):
|
||||
for op in operations:
|
||||
if op.inst == Instruction.put:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._put(src, op.dst.rank, op.dst.buffer, op.dst.index)
|
||||
elif op.inst == Instruction.put_packet:
|
||||
src_format = op.extra.get("src_format")
|
||||
temp_buffer = op.extra.get("temp_buffer")
|
||||
temp_buffer_index = op.extra.get("temp_buffer_index")
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._put_packet(
|
||||
src,
|
||||
op.dst.rank,
|
||||
op.dst.buffer,
|
||||
op.dst.index,
|
||||
sendtb=op.tb,
|
||||
src_format=src_format,
|
||||
chan_type=op.channel_type,
|
||||
temp_buffer=temp_buffer,
|
||||
temp_buffer_index=temp_buffer_index,
|
||||
)
|
||||
elif op.inst == Instruction.get:
|
||||
dst = ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size)
|
||||
self._get(dst, op.src.rank, op.src.buffer, op.src.index)
|
||||
elif op.inst == Instruction.copy:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index)
|
||||
elif op.inst == Instruction.copy_packet:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index)
|
||||
elif op.inst == Instruction.reduce:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size))
|
||||
elif op.inst == Instruction.reduce_packet:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size))
|
||||
elif op.inst == Instruction.group_load_reduce:
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._group_load_reduce(src, other_chunkrefs=op.srcs)
|
||||
elif op.inst == Instruction.group_store:
|
||||
dsts = op.extra.get("dsts")
|
||||
index = op.extra.get("index")
|
||||
buffer = op.extra.get("buffer")
|
||||
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
|
||||
self._group_store(src, dsts=dsts, index=index, buffer=buffer)
|
||||
|
||||
def check(self, collective: Collective, operations: List):
|
||||
self.collective = collective
|
||||
self.buffers = collective.init_buffers()
|
||||
self._execute(operations)
|
||||
return collective.check(self)
|
||||
@@ -8,6 +8,8 @@ from mscclpp.language.types import DataFormat, ChannelType, ChunkRef, Replicatio
|
||||
from mscclpp.language.ir import *
|
||||
from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG
|
||||
from mscclpp.language.rank import Rank
|
||||
from mscclpp.language.topo_sort import OperationDependencyGraph
|
||||
from mscclpp.language.collective_checker import CollectiveChecker
|
||||
|
||||
_current_program = None
|
||||
|
||||
@@ -54,6 +56,8 @@ class MSCCLPPProgram:
|
||||
# Initialize the input buffers
|
||||
self.buffers = collective.init_buffers()
|
||||
self.instr_dag = InstructionDAG(self.num_ranks, self.buffers)
|
||||
self.op_dep_dag = OperationDependencyGraph()
|
||||
self.collective_checker = CollectiveChecker()
|
||||
self.ranks = []
|
||||
for r in range(self.num_ranks):
|
||||
self.ranks.append(Rank(r))
|
||||
@@ -98,17 +102,6 @@ class MSCCLPPProgram:
|
||||
for i in range(size):
|
||||
db[dst_index + i] = sb[src_index + i]
|
||||
|
||||
# Tracks a reduce operation on the buffers
|
||||
def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
|
||||
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
|
||||
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
|
||||
sb = self.buffers[src][src_buffer]
|
||||
db = self.buffers[dst][dst_buffer]
|
||||
for i in range(size):
|
||||
reduce_chunk = db[dst_index + i]
|
||||
sent_chunk = sb[src_index + i]
|
||||
db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)
|
||||
|
||||
def get_ref(self, rank, buffer, index, size):
|
||||
buffer, index = self.collective.get_buffer_index(rank, buffer, index)
|
||||
return Ref(rank, buffer, index, size, self)
|
||||
@@ -129,7 +122,7 @@ class MSCCLPPProgram:
|
||||
# Checks that all chunks that should be on each rank
|
||||
# are present in the output buffer.
|
||||
def check(self):
|
||||
return self.collective.check(self)
|
||||
return self.collective_checker.check(self.collective, self.op_dep_dag.get_execution_order())
|
||||
|
||||
# Lower program to MSCCLPP
|
||||
def lower(self):
|
||||
@@ -248,6 +241,9 @@ class Ref(ChunkRef):
|
||||
return dst_chunkref
|
||||
|
||||
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
|
||||
op = Op(inst=Instruction.put, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
return self._put(dst, buffer, index, sendtb, DataFormat.raw, chan_type)
|
||||
|
||||
def put_packet(
|
||||
@@ -261,6 +257,16 @@ class Ref(ChunkRef):
|
||||
temp_buffer=None,
|
||||
temp_buffer_index=-1,
|
||||
):
|
||||
extra = {"src_format": src_format, "temp_buffer": temp_buffer, "temp_buffer_index": temp_buffer_index}
|
||||
op = Op(
|
||||
inst=Instruction.put_packet,
|
||||
rank=self.rank,
|
||||
src=self,
|
||||
dst=ChunkRef(dst, buffer, index, self.size),
|
||||
extra=extra,
|
||||
)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
chunk_ref = self
|
||||
if chan_type == ChannelType.port and src_format == DataFormat.raw:
|
||||
assert temp_buffer is not None, "Need to specify a temporary buffer for port channels"
|
||||
@@ -270,6 +276,9 @@ class Ref(ChunkRef):
|
||||
return chunk_ref._put(dst, buffer, index, sendtb, src_format, chan_type, True)
|
||||
|
||||
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
|
||||
op = Op(inst=Instruction.get, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
self.prog.check_buffer_exists(src, buffer)
|
||||
sender = src
|
||||
receiver = self.rank
|
||||
@@ -285,6 +294,9 @@ class Ref(ChunkRef):
|
||||
# to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type).
|
||||
# Then we can use DAG info to reduce the number of channels.
|
||||
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
|
||||
op = Op(inst=Instruction.signal, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
sender = self.rank
|
||||
receiver = dst
|
||||
assert sender != receiver, "Cannot signal to the same rank"
|
||||
@@ -295,6 +307,9 @@ class Ref(ChunkRef):
|
||||
|
||||
# only port channel need to use this function
|
||||
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port):
|
||||
op = Op(inst=Instruction.flush, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
assert chan_type == ChannelType.port, "Only port channel can use flush"
|
||||
sender = self.rank
|
||||
receiver = dst
|
||||
@@ -305,6 +320,9 @@ class Ref(ChunkRef):
|
||||
self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)
|
||||
|
||||
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
|
||||
op = Op(inst=Instruction.wait, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
sender = src
|
||||
receiver = self.rank
|
||||
assert sender != receiver, "Cannot wait on the same rank"
|
||||
@@ -330,18 +348,21 @@ class Ref(ChunkRef):
|
||||
|
||||
# Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
|
||||
def copy(self, dst, buffer=None, index=-1, sendtb=-1):
|
||||
op = Op(inst=Instruction.copy, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
return self._copy(dst, buffer, index, sendtb)
|
||||
|
||||
def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
|
||||
op = Op(inst=Instruction.copy_packet, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)
|
||||
|
||||
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False):
|
||||
dst = self.rank
|
||||
src = other_chunkref.rank
|
||||
|
||||
self.prog.apply_reduce(
|
||||
src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
|
||||
)
|
||||
if use_packet:
|
||||
assert src == dst, "Packet reduce only supports intra-rank communication"
|
||||
|
||||
@@ -354,10 +375,16 @@ class Ref(ChunkRef):
|
||||
|
||||
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
|
||||
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory):
|
||||
op = Op(inst=Instruction.reduce, rank=self.rank, src=self, dst=other_chunkref)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
return self._reduce(other_chunkref, recvtb, channel_type)
|
||||
|
||||
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
|
||||
def reduce_packet(self, other_chunkref, recvtb=-1):
|
||||
op = Op(inst=Instruction.reduce_packet, rank=self.rank, src=self, dst=other_chunkref)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
return self._reduce(other_chunkref, recvtb, use_packet=True)
|
||||
|
||||
# """
|
||||
@@ -366,6 +393,9 @@ class Ref(ChunkRef):
|
||||
# """
|
||||
# Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref
|
||||
def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
|
||||
op = Op(inst=Instruction.group_load_reduce, rank=self.rank, src=self, dst=None, srcs=other_chunkrefs)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
assert (
|
||||
len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
|
||||
), "Group load reduce only supports nvls channel"
|
||||
@@ -377,21 +407,15 @@ class Ref(ChunkRef):
|
||||
assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
|
||||
assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"
|
||||
|
||||
src_chunkref = other_chunkref
|
||||
self.prog.apply_reduce(
|
||||
src_chunkref.rank,
|
||||
src_chunkref.buffer,
|
||||
src_chunkref.index,
|
||||
self.rank,
|
||||
self.buffer,
|
||||
self.index,
|
||||
self.size,
|
||||
)
|
||||
self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
|
||||
return self
|
||||
|
||||
# Copies the chunk(s) referenced by this chunkref onto other_chunkrefs
|
||||
def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
|
||||
extra = {"dsts": dsts, "index": index, "buffer": buffer}
|
||||
op = Op(inst=Instruction.group_store, rank=self.rank, src=self, dst=None, extra=extra)
|
||||
self.prog.op_dep_dag.add_operation(op)
|
||||
|
||||
for dst in dsts:
|
||||
self.prog.check_buffer_exists(dst, buffer)
|
||||
assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
|
||||
@@ -434,7 +458,7 @@ def chunk(rank, buffer, index, size=1) -> Ref:
|
||||
if buffer is Buffer.scratch:
|
||||
if buffer not in _curr().buffers[rank]:
|
||||
_curr().buffers[rank][buffer] = BufferSlice(Buffer.scratch, buffer)
|
||||
if index >= len(_curr().buffers[rank][buffer]):
|
||||
if index >= len(_curr().buffers[rank][buffer]) or _curr().buffers[rank][buffer][index] is None:
|
||||
_curr().buffers[rank][buffer][index] = ChunkRef(rank, buffer, index, size)
|
||||
|
||||
if _curr().buffers[rank][buffer][index] is None:
|
||||
|
||||
104
python/mscclpp/language/topo_sort.py
Normal file
104
python/mscclpp/language/topo_sort.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from mscclpp.language.types import *
|
||||
from queue import Queue
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
class OperationDependencyGraph:
|
||||
"""
|
||||
A DAG structure to enforce correct execution order of collective communication operations.
|
||||
Supports topological sorting based on rank/threadblock execution and signal/wait synchronization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.root_nodes: List["Node"] = []
|
||||
self.previous_node: Dict[Tuple[int, int], int] = {}
|
||||
self.signalling: Dict[Tuple[int, int, int], Queue] = {}
|
||||
self.waiting: Dict[Tuple[int, int, int], Queue] = {}
|
||||
|
||||
def add_operation(self, op: "Op"):
|
||||
"""
|
||||
Inserts an operation into the DAG, adding edges based on dependencies.
|
||||
"""
|
||||
node = self.Node(op)
|
||||
rank = op.rank
|
||||
tb = op.tb
|
||||
|
||||
if op.inst == Instruction.barrier:
|
||||
for tb in op.extra.get("tb_list", []):
|
||||
if (rank, tb) not in self.previous_node:
|
||||
self.previous_node[(rank, tb)] = node
|
||||
self.root_nodes.append(node)
|
||||
else:
|
||||
prev_node = self.previous_node[(rank, tb)]
|
||||
prev_node.next_nodes.append(node)
|
||||
node.input += 1
|
||||
self.previous_node[(rank, tb)] = node
|
||||
|
||||
else:
|
||||
if (rank, tb) not in self.previous_node:
|
||||
self.previous_node[(rank, tb)] = node
|
||||
self.root_nodes.append(node)
|
||||
else:
|
||||
prev_node = self.previous_node[(rank, tb)]
|
||||
prev_node.next_nodes.append(node)
|
||||
node.input += 1
|
||||
self.previous_node[(rank, tb)] = node
|
||||
|
||||
if op.inst == Instruction.signal:
|
||||
if (op.src.rank, op.dst.rank, tb) not in self.waiting or self.waiting[
|
||||
(op.src.rank, op.dst.rank, tb)
|
||||
].empty():
|
||||
if (op.src.rank, op.dst.rank, tb) not in self.signalling:
|
||||
self.signalling[(op.src.rank, op.dst.rank, tb)] = Queue()
|
||||
self.signalling[(op.src.rank, op.dst.rank, tb)].put(node)
|
||||
else:
|
||||
waiting_node = self.waiting[(op.src.rank, op.dst.rank, tb)].get()
|
||||
node.next_nodes.append(waiting_node)
|
||||
waiting_node.input += 1
|
||||
|
||||
if op.inst == Instruction.wait:
|
||||
if (op.src.rank, op.dst.rank, tb) not in self.signalling or self.signalling[
|
||||
(op.src.rank, op.dst.rank, tb)
|
||||
].empty():
|
||||
if (op.src.rank, op.dst.rank, tb) not in self.waiting:
|
||||
self.waiting[(op.src.rank, op.dst.rank, tb)] = Queue()
|
||||
self.waiting[(op.src.rank, op.dst.rank, tb)].put(node)
|
||||
else:
|
||||
signalling_node = self.signalling[(op.src.rank, op.dst.rank, tb)].get()
|
||||
signalling_node.next_nodes.append(node)
|
||||
node.input += 1
|
||||
|
||||
def get_execution_order(self):
|
||||
"""
|
||||
Returns the order of operations in the DAG.
|
||||
"""
|
||||
order = []
|
||||
queue = Queue()
|
||||
for node in self.root_nodes:
|
||||
queue.put(node)
|
||||
|
||||
while not queue.empty():
|
||||
node = queue.get()
|
||||
op = node.operation
|
||||
order.append(op)
|
||||
for next_node in node.next_nodes:
|
||||
next_node.reach += 1
|
||||
if next_node.reach == next_node.input:
|
||||
queue.put(next_node)
|
||||
|
||||
return order
|
||||
|
||||
class Node:
|
||||
operation: "Op"
|
||||
next_nodes: list
|
||||
input: int
|
||||
reach: int
|
||||
|
||||
def __init__(self, operation: "Op"):
|
||||
self.operation = operation
|
||||
self.next_nodes = []
|
||||
self.input = 0
|
||||
self.reach = 0
|
||||
Reference in New Issue
Block a user