From 51eca89d20f0cfb3764ccd764338d7b22cd486a6 Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 29 Apr 2025 13:29:28 -0700 Subject: [PATCH] Enhance Collective Check at MSCCLang (#511) --- python/mscclpp/language/collective_checker.py | 153 ++++++++++++++++++ python/mscclpp/language/program.py | 76 ++++++--- python/mscclpp/language/topo_sort.py | 104 ++++++++++++ 3 files changed, 307 insertions(+), 26 deletions(-) create mode 100644 python/mscclpp/language/collective_checker.py create mode 100644 python/mscclpp/language/topo_sort.py diff --git a/python/mscclpp/language/collective_checker.py b/python/mscclpp/language/collective_checker.py new file mode 100644 index 00000000..7952172a --- /dev/null +++ b/python/mscclpp/language/collective_checker.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List +from mscclpp.language.buffer import * +from mscclpp.language.collectives import Collective +from mscclpp.language.types import DataFormat, ChannelType, ChunkRef +from mscclpp.language.ir import * + + +class CollectiveChecker: + def __init__(self): + self.collective = None + self.buffers = None + + def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + db[dst_index + i] = sb[src_index + i] + + def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): + src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) + dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) + sb = self.buffers[src][src_buffer] + db = self.buffers[dst][dst_buffer] + for i in range(size): + reduce_chunk = db[dst_index + i] + sent_chunk = sb[src_index + i] + db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) + + def check_buffer_exists(self, rank, name): + if name not in self.buffers[rank]: + self.buffers[rank][name] = BufferSlice(Buffer.scratch, name) + + def _get_buffer_index(self, src_rank, remote_rank, buffer, index): + if index == -1 and buffer == None: + return src_rank.buffer, src_rank.index + elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output: + return buffer, self.buffers[remote_rank][buffer].instance_size() + return buffer, index + + def _put(self, src, dst, buffer=None, index=-1): + self.check_buffer_exists(dst, buffer) + buffer, index = self._get_buffer_index(src, dst, buffer, index) + self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size) + + def _put_packet( + self, + src, + dst, + buffer=None, + index=-1, + sendtb=-1, + src_format=DataFormat.raw, + chan_type=ChannelType.memory, + temp_buffer=None, + temp_buffer_index=-1, + ): + chunk_ref = src + if chan_type == ChannelType.port and src_format == DataFormat.raw: + self._copy(src.rank, temp_buffer, temp_buffer_index) + self._put(chunk_ref, dst, buffer, index) + + def _get(self, dst, src, buffer=None, index=-1): + self.check_buffer_exists(src, buffer) + buffer, index = self._get_buffer_index(dst, src, buffer, index) + self.apply_send(src, buffer, index, dst.rank, dst.buffer, dst.index, dst.size) + + def _copy(self, src, dst, buffer=None, index=-1): + self.check_buffer_exists(dst, buffer) + buffer, index = self._get_buffer_index(src, dst, buffer, index) + self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size) + + def _reduce(self, src, other_chunkref): + dst = src.rank + src_rank = other_chunkref.rank + self.apply_reduce(src_rank, other_chunkref.buffer, other_chunkref.index, dst, src.buffer, src.index, src.size) + + def _group_load_reduce(self, src, other_chunkrefs: list): + for other_chunkref in other_chunkrefs: + src_chunkref = other_chunkref + self.apply_reduce( + src_chunkref.rank, + src_chunkref.buffer, + src_chunkref.index, + src.rank, + src.buffer, + src.index, + src.size, + ) + + def _group_store(self, src, dsts: list, index=-1, buffer=None): + for dst in dsts: + self.check_buffer_exists(dst, buffer) + + for dst in dsts: + buffer, index = self._get_buffer_index(src, dst, buffer, index) + self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size) + + def _execute(self, operations): + for op in operations: + if op.inst == Instruction.put: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._put(src, op.dst.rank, op.dst.buffer, op.dst.index) + elif op.inst == Instruction.put_packet: + src_format = op.extra.get("src_format") + temp_buffer = op.extra.get("temp_buffer") + temp_buffer_index = op.extra.get("temp_buffer_index") + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._put_packet( + src, + op.dst.rank, + op.dst.buffer, + op.dst.index, + sendtb=op.tb, + src_format=src_format, + chan_type=op.channel_type, + temp_buffer=temp_buffer, + temp_buffer_index=temp_buffer_index, + ) + elif op.inst == Instruction.get: + dst = ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size) + self._get(dst, op.src.rank, op.src.buffer, op.src.index) + elif op.inst == Instruction.copy: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index) + elif op.inst == Instruction.copy_packet: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index) + elif op.inst == Instruction.reduce: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size)) + elif op.inst == Instruction.reduce_packet: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size)) + elif op.inst == Instruction.group_load_reduce: + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._group_load_reduce(src, other_chunkrefs=op.srcs) + elif op.inst == Instruction.group_store: + dsts = op.extra.get("dsts") + index = op.extra.get("index") + buffer = op.extra.get("buffer") + src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size) + self._group_store(src, dsts=dsts, index=index, buffer=buffer) + + def check(self, collective: Collective, operations: List): + self.collective = collective + self.buffers = collective.init_buffers() + self._execute(operations) + return collective.check(self) diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 681a2fe3..9e5a03ea 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -8,6 +8,8 @@ from mscclpp.language.types import DataFormat, ChannelType, ChunkRef, Replicatio from mscclpp.language.ir import * from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG from mscclpp.language.rank import Rank +from mscclpp.language.topo_sort import OperationDependencyGraph +from mscclpp.language.collective_checker import CollectiveChecker _current_program = None @@ -54,6 +56,8 @@ class MSCCLPPProgram: # Initialize the input buffers self.buffers = collective.init_buffers() self.instr_dag = InstructionDAG(self.num_ranks, self.buffers) + self.op_dep_dag = OperationDependencyGraph() + self.collective_checker = CollectiveChecker() self.ranks = [] for r in range(self.num_ranks): self.ranks.append(Rank(r)) @@ -98,17 +102,6 @@ class MSCCLPPProgram: for i in range(size): db[dst_index + i] = sb[src_index + i] - # Tracks a reduce operation on the buffers - def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size): - src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index) - dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index) - sb = self.buffers[src][src_buffer] - db = self.buffers[dst][dst_buffer] - for i in range(size): - reduce_chunk = db[dst_index + i] - sent_chunk = sb[src_index + i] - db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk) - def get_ref(self, rank, buffer, index, size): buffer, index = self.collective.get_buffer_index(rank, buffer, index) return Ref(rank, buffer, index, size, self) @@ -129,7 +122,7 @@ class MSCCLPPProgram: # Checks that all chunks that should be on each rank # are present in the output buffer. def check(self): - return self.collective.check(self) + return self.collective_checker.check(self.collective, self.op_dep_dag.get_execution_order()) # Lower program to MSCCLPP def lower(self): @@ -248,6 +241,9 @@ class Ref(ChunkRef): return dst_chunkref def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory): + op = Op(inst=Instruction.put, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size)) + self.prog.op_dep_dag.add_operation(op) + return self._put(dst, buffer, index, sendtb, DataFormat.raw, chan_type) def put_packet( @@ -261,6 +257,16 @@ class Ref(ChunkRef): temp_buffer=None, temp_buffer_index=-1, ): + extra = {"src_format": src_format, "temp_buffer": temp_buffer, "temp_buffer_index": temp_buffer_index} + op = Op( + inst=Instruction.put_packet, + rank=self.rank, + src=self, + dst=ChunkRef(dst, buffer, index, self.size), + extra=extra, + ) + self.prog.op_dep_dag.add_operation(op) + chunk_ref = self if chan_type == ChannelType.port and src_format == DataFormat.raw: assert temp_buffer is not None, "Need to specify a temporary buffer for port channels" @@ -270,6 +276,9 @@ class Ref(ChunkRef): return chunk_ref._put(dst, buffer, index, sendtb, src_format, chan_type, True) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory): + op = Op(inst=Instruction.get, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self) + self.prog.op_dep_dag.add_operation(op) + self.prog.check_buffer_exists(src, buffer) sender = src receiver = self.rank @@ -285,6 +294,9 @@ class Ref(ChunkRef): # to infer the tb index from the instruction DAG Add a channel is define as (send_tb, src_buffer, recv_tb, dst_buffer, type). # Then we can use DAG info to reduce the number of channels. def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory): + op = Op(inst=Instruction.signal, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size)) + self.prog.op_dep_dag.add_operation(op) + sender = self.rank receiver = dst assert sender != receiver, "Cannot signal to the same rank" @@ -295,6 +307,9 @@ class Ref(ChunkRef): # only port channel need to use this function def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port): + op = Op(inst=Instruction.flush, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size)) + self.prog.op_dep_dag.add_operation(op) + assert chan_type == ChannelType.port, "Only port channel can use flush" sender = self.rank receiver = dst @@ -305,6 +320,9 @@ class Ref(ChunkRef): self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb) def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory): + op = Op(inst=Instruction.wait, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self) + self.prog.op_dep_dag.add_operation(op) + sender = src receiver = self.rank assert sender != receiver, "Cannot wait on the same rank" @@ -330,18 +348,21 @@ class Ref(ChunkRef): # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index) def copy(self, dst, buffer=None, index=-1, sendtb=-1): + op = Op(inst=Instruction.copy, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size)) + self.prog.op_dep_dag.add_operation(op) + return self._copy(dst, buffer, index, sendtb) def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1): + op = Op(inst=Instruction.copy_packet, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size)) + self.prog.op_dep_dag.add_operation(op) + return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False) def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False): dst = self.rank src = other_chunkref.rank - self.prog.apply_reduce( - src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size - ) if use_packet: assert src == dst, "Packet reduce only supports intra-rank communication" @@ -354,10 +375,16 @@ class Ref(ChunkRef): # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory): + op = Op(inst=Instruction.reduce, rank=self.rank, src=self, dst=other_chunkref) + self.prog.op_dep_dag.add_operation(op) + return self._reduce(other_chunkref, recvtb, channel_type) # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref def reduce_packet(self, other_chunkref, recvtb=-1): + op = Op(inst=Instruction.reduce_packet, rank=self.rank, src=self, dst=other_chunkref) + self.prog.op_dep_dag.add_operation(op) + return self._reduce(other_chunkref, recvtb, use_packet=True) # """ @@ -366,6 +393,9 @@ class Ref(ChunkRef): # """ # Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls): + op = Op(inst=Instruction.group_load_reduce, rank=self.rank, src=self, dst=None, srcs=other_chunkrefs) + self.prog.op_dep_dag.add_operation(op) + assert ( len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls ), "Group load reduce only supports nvls channel" @@ -377,21 +407,15 @@ class Ref(ChunkRef): assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer" assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index" - src_chunkref = other_chunkref - self.prog.apply_reduce( - src_chunkref.rank, - src_chunkref.buffer, - src_chunkref.index, - self.rank, - self.buffer, - self.index, - self.size, - ) self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type) return self # Copies the chunk(s) referenced by this chunkref onto other_chunkrefs def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls): + extra = {"dsts": dsts, "index": index, "buffer": buffer} + op = Op(inst=Instruction.group_store, rank=self.rank, src=self, dst=None, extra=extra) + self.prog.op_dep_dag.add_operation(op) + for dst in dsts: self.prog.check_buffer_exists(dst, buffer) assert index == -1 or self.index == index, "Group store only supports chunks with the same index" @@ -434,7 +458,7 @@ def chunk(rank, buffer, index, size=1) -> Ref: if buffer is Buffer.scratch: if buffer not in _curr().buffers[rank]: _curr().buffers[rank][buffer] = BufferSlice(Buffer.scratch, buffer) - if index >= len(_curr().buffers[rank][buffer]): + if index >= len(_curr().buffers[rank][buffer]) or _curr().buffers[rank][buffer][index] is None: _curr().buffers[rank][buffer][index] = ChunkRef(rank, buffer, index, size) if _curr().buffers[rank][buffer][index] is None: diff --git a/python/mscclpp/language/topo_sort.py b/python/mscclpp/language/topo_sort.py new file mode 100644 index 00000000..3254d4fe --- /dev/null +++ b/python/mscclpp/language/topo_sort.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.types import * +from queue import Queue +from typing import Dict, Tuple + + +class OperationDependencyGraph: + """ + A DAG structure to enforce correct execution order of collective communication operations. + Supports topological sorting based on rank/threadblock execution and signal/wait synchronization. + """ + + def __init__(self): + self.root_nodes: List["Node"] = [] + self.previous_node: Dict[Tuple[int, int], int] = {} + self.signalling: Dict[Tuple[int, int, int], Queue] = {} + self.waiting: Dict[Tuple[int, int, int], Queue] = {} + + def add_operation(self, op: "Op"): + """ + Inserts an operation into the DAG, adding edges based on dependencies. + """ + node = self.Node(op) + rank = op.rank + tb = op.tb + + if op.inst == Instruction.barrier: + for tb in op.extra.get("tb_list", []): + if (rank, tb) not in self.previous_node: + self.previous_node[(rank, tb)] = node + self.root_nodes.append(node) + else: + prev_node = self.previous_node[(rank, tb)] + prev_node.next_nodes.append(node) + node.input += 1 + self.previous_node[(rank, tb)] = node + + else: + if (rank, tb) not in self.previous_node: + self.previous_node[(rank, tb)] = node + self.root_nodes.append(node) + else: + prev_node = self.previous_node[(rank, tb)] + prev_node.next_nodes.append(node) + node.input += 1 + self.previous_node[(rank, tb)] = node + + if op.inst == Instruction.signal: + if (op.src.rank, op.dst.rank, tb) not in self.waiting or self.waiting[ + (op.src.rank, op.dst.rank, tb) + ].empty(): + if (op.src.rank, op.dst.rank, tb) not in self.signalling: + self.signalling[(op.src.rank, op.dst.rank, tb)] = Queue() + self.signalling[(op.src.rank, op.dst.rank, tb)].put(node) + else: + waiting_node = self.waiting[(op.src.rank, op.dst.rank, tb)].get() + node.next_nodes.append(waiting_node) + waiting_node.input += 1 + + if op.inst == Instruction.wait: + if (op.src.rank, op.dst.rank, tb) not in self.signalling or self.signalling[ + (op.src.rank, op.dst.rank, tb) + ].empty(): + if (op.src.rank, op.dst.rank, tb) not in self.waiting: + self.waiting[(op.src.rank, op.dst.rank, tb)] = Queue() + self.waiting[(op.src.rank, op.dst.rank, tb)].put(node) + else: + signalling_node = self.signalling[(op.src.rank, op.dst.rank, tb)].get() + signalling_node.next_nodes.append(node) + node.input += 1 + + def get_execution_order(self): + """ + Returns the order of operations in the DAG. + """ + order = [] + queue = Queue() + for node in self.root_nodes: + queue.put(node) + + while not queue.empty(): + node = queue.get() + op = node.operation + order.append(op) + for next_node in node.next_nodes: + next_node.reach += 1 + if next_node.reach == next_node.input: + queue.put(next_node) + + return order + + class Node: + operation: "Op" + next_nodes: list + input: int + reach: int + + def __init__(self, operation: "Op"): + self.operation = operation + self.next_nodes = [] + self.input = 0 + self.reach = 0