Enhance Collective Check in MSCCLang (#511)

Caio Rocha
2025-04-29 13:29:28 -07:00
committed by GitHub
parent affca7d9bc
commit 51eca89d20
3 changed files with 307 additions and 26 deletions


@@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
from mscclpp.language.buffer import *
from mscclpp.language.collectives import Collective
from mscclpp.language.types import DataFormat, ChannelType, ChunkRef
from mscclpp.language.ir import *
class CollectiveChecker:
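"""
Replays recorded operations on simulated per-rank buffers and asks the collective
whether the resulting buffer state is correct.
"""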
def __init__(self):
self.collective = None
self.buffers = None
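# Tracks a send operation on the buffers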
def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
sb = self.buffers[src][src_buffer]
db = self.buffers[dst][dst_buffer]
for i in range(size):
db[dst_index + i] = sb[src_index + i]
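# Tracks a reduce operation on the buffers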
def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
sb = self.buffers[src][src_buffer]
db = self.buffers[dst][dst_buffer]
for i in range(size):
reduce_chunk = db[dst_index + i]
sent_chunk = sb[src_index + i]
db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)
def check_buffer_exists(self, rank, name):
if name not in self.buffers[rank]:
self.buffers[rank][name] = BufferSlice(Buffer.scratch, name)
def _get_buffer_index(self, src_rank, remote_rank, buffer, index):
if index == -1 and buffer == None:
return src_rank.buffer, src_rank.index
elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
return buffer, self.buffers[remote_rank][buffer].instance_size()
return buffer, index
def _put(self, src, dst, buffer=None, index=-1):
self.check_buffer_exists(dst, buffer)
buffer, index = self._get_buffer_index(src, dst, buffer, index)
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
def _put_packet(
self,
src,
dst,
buffer=None,
index=-1,
sendtb=-1,
src_format=DataFormat.raw,
chan_type=ChannelType.memory,
temp_buffer=None,
temp_buffer_index=-1,
):
chunk_ref = src
if chan_type == ChannelType.port and src_format == DataFormat.raw:
self._copy(src.rank, temp_buffer, temp_buffer_index)
self._put(chunk_ref, dst, buffer, index)
def _get(self, dst, src, buffer=None, index=-1):
self.check_buffer_exists(src, buffer)
buffer, index = self._get_buffer_index(dst, src, buffer, index)
self.apply_send(src, buffer, index, dst.rank, dst.buffer, dst.index, dst.size)
def _copy(self, src, dst, buffer=None, index=-1):
self.check_buffer_exists(dst, buffer)
buffer, index = self._get_buffer_index(src, dst, buffer, index)
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
def _reduce(self, src, other_chunkref):
dst = src.rank
src_rank = other_chunkref.rank
self.apply_reduce(src_rank, other_chunkref.buffer, other_chunkref.index, dst, src.buffer, src.index, src.size)
def _group_load_reduce(self, src, other_chunkrefs: list):
for other_chunkref in other_chunkrefs:
src_chunkref = other_chunkref
self.apply_reduce(
src_chunkref.rank,
src_chunkref.buffer,
src_chunkref.index,
src.rank,
src.buffer,
src.index,
src.size,
)
def _group_store(self, src, dsts: list, index=-1, buffer=None):
for dst in dsts:
self.check_buffer_exists(dst, buffer)
for dst in dsts:
buffer, index = self._get_buffer_index(src, dst, buffer, index)
self.apply_send(src.rank, src.buffer, src.index, dst, buffer, index, src.size)
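# Replays the recorded operations on the simulated buffers, dispatching on the instruction type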
def _execute(self, operations):
for op in operations:
if op.inst == Instruction.put:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._put(src, op.dst.rank, op.dst.buffer, op.dst.index)
elif op.inst == Instruction.put_packet:
src_format = op.extra.get("src_format")
temp_buffer = op.extra.get("temp_buffer")
temp_buffer_index = op.extra.get("temp_buffer_index")
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._put_packet(
src,
op.dst.rank,
op.dst.buffer,
op.dst.index,
sendtb=op.tb,
src_format=src_format,
chan_type=op.channel_type,
temp_buffer=temp_buffer,
temp_buffer_index=temp_buffer_index,
)
elif op.inst == Instruction.get:
dst = ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size)
self._get(dst, op.src.rank, op.src.buffer, op.src.index)
elif op.inst == Instruction.copy:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index)
elif op.inst == Instruction.copy_packet:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._copy(src, op.dst.rank, op.dst.buffer, op.dst.index)
elif op.inst == Instruction.reduce:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size))
elif op.inst == Instruction.reduce_packet:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._reduce(src, ChunkRef(op.dst.rank, op.dst.buffer, op.dst.index, op.dst.size))
elif op.inst == Instruction.group_load_reduce:
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._group_load_reduce(src, other_chunkrefs=op.srcs)
elif op.inst == Instruction.group_store:
dsts = op.extra.get("dsts")
index = op.extra.get("index")
buffer = op.extra.get("buffer")
src = ChunkRef(op.src.rank, op.src.buffer, op.src.index, op.src.size)
self._group_store(src, dsts=dsts, index=index, buffer=buffer)
def check(self, collective: Collective, operations: List):
self.collective = collective
self.buffers = collective.init_buffers()
self._execute(operations)
return collective.check(self)
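For intuition, the checker amounts to replaying the recorded data movement on plain per-rank buffers and then asking the collective whether the final state is correct. Below is a minimal, self-contained sketch of that idea for an all-gather-shaped pattern; the toy apply_send, buffer dictionaries, and expected-output check are illustrations, not the mscclpp classes used above.

```python
# Toy stand-in for apply_send: copy `size` chunks between simulated buffers.
def apply_send(buffers, src, src_buf, src_idx, dst, dst_buf, dst_idx, size):
    for i in range(size):
        buffers[dst][dst_buf][dst_idx + i] = buffers[src][src_buf][src_idx + i]

num_ranks = 2
buffers = {
    r: {"input": [f"chunk{r}"], "output": [None] * num_ranks}
    for r in range(num_ranks)
}

# Operations an all-gather-style program would record:
# every rank places its input chunk at slot `src` of every rank's output.
ops = [
    (src, "input", 0, dst, "output", src, 1)
    for src in range(num_ranks)
    for dst in range(num_ranks)
]
for op in ops:
    apply_send(buffers, *op)

# Toy equivalent of collective.check(): each rank must hold every chunk, in order.
expected = [f"chunk{i}" for i in range(num_ranks)]
assert all(buffers[r]["output"] == expected for r in range(num_ranks))
```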


@@ -8,6 +8,8 @@ from mscclpp.language.types import DataFormat, ChannelType, ChunkRef, Replicatio
from mscclpp.language.ir import *
from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG
from mscclpp.language.rank import Rank
from mscclpp.language.topo_sort import OperationDependencyGraph
from mscclpp.language.collective_checker import CollectiveChecker
_current_program = None
@@ -54,6 +56,8 @@ class MSCCLPPProgram:
# Initialize the input buffers
self.buffers = collective.init_buffers()
self.instr_dag = InstructionDAG(self.num_ranks, self.buffers)
self.op_dep_dag = OperationDependencyGraph()
self.collective_checker = CollectiveChecker()
self.ranks = []
for r in range(self.num_ranks):
self.ranks.append(Rank(r))
@@ -98,17 +102,6 @@ class MSCCLPPProgram:
for i in range(size):
db[dst_index + i] = sb[src_index + i]
# Tracks a reduce operation on the buffers
def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
sb = self.buffers[src][src_buffer]
db = self.buffers[dst][dst_buffer]
for i in range(size):
reduce_chunk = db[dst_index + i]
sent_chunk = sb[src_index + i]
db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)
def get_ref(self, rank, buffer, index, size):
buffer, index = self.collective.get_buffer_index(rank, buffer, index)
return Ref(rank, buffer, index, size, self)
@@ -129,7 +122,7 @@ class MSCCLPPProgram:
# Checks that all chunks that should be on each rank
# are present in the output buffer.
def check(self):
return self.collective.check(self)
return self.collective_checker.check(self.collective, self.op_dep_dag.get_execution_order())
# Lower program to MSCCLPP
def lower(self):
@@ -248,6 +241,9 @@ class Ref(ChunkRef):
return dst_chunkref
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
op = Op(inst=Instruction.put, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
self.prog.op_dep_dag.add_operation(op)
return self._put(dst, buffer, index, sendtb, DataFormat.raw, chan_type)
def put_packet(
@@ -261,6 +257,16 @@ class Ref(ChunkRef):
temp_buffer=None,
temp_buffer_index=-1,
):
extra = {"src_format": src_format, "temp_buffer": temp_buffer, "temp_buffer_index": temp_buffer_index}
op = Op(
inst=Instruction.put_packet,
rank=self.rank,
src=self,
dst=ChunkRef(dst, buffer, index, self.size),
extra=extra,
)
self.prog.op_dep_dag.add_operation(op)
chunk_ref = self
if chan_type == ChannelType.port and src_format == DataFormat.raw:
assert temp_buffer is not None, "Need to specify a temporary buffer for port channels"
@@ -270,6 +276,9 @@ class Ref(ChunkRef):
return chunk_ref._put(dst, buffer, index, sendtb, src_format, chan_type, True)
def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
op = Op(inst=Instruction.get, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self)
self.prog.op_dep_dag.add_operation(op)
self.prog.check_buffer_exists(src, buffer)
sender = src
receiver = self.rank
@@ -285,6 +294,9 @@ class Ref(ChunkRef):
# to infer the tb index from the instruction DAG. A channel is defined as (send_tb, src_buffer, recv_tb, dst_buffer, type).
# Then we can use DAG info to reduce the number of channels.
def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
op = Op(inst=Instruction.signal, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
self.prog.op_dep_dag.add_operation(op)
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot signal to the same rank"
@@ -295,6 +307,9 @@ class Ref(ChunkRef):
# Only the port channel needs to use this function
def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port):
op = Op(inst=Instruction.flush, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
self.prog.op_dep_dag.add_operation(op)
assert chan_type == ChannelType.port, "Only port channel can use flush"
sender = self.rank
receiver = dst
@@ -305,6 +320,9 @@ class Ref(ChunkRef):
self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)
def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
op = Op(inst=Instruction.wait, rank=self.rank, src=ChunkRef(src, buffer, index, self.size), dst=self)
self.prog.op_dep_dag.add_operation(op)
sender = src
receiver = self.rank
assert sender != receiver, "Cannot wait on the same rank"
@@ -330,18 +348,21 @@ class Ref(ChunkRef):
# Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
def copy(self, dst, buffer=None, index=-1, sendtb=-1):
op = Op(inst=Instruction.copy, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
self.prog.op_dep_dag.add_operation(op)
return self._copy(dst, buffer, index, sendtb)
def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
op = Op(inst=Instruction.copy_packet, rank=self.rank, src=self, dst=ChunkRef(dst, buffer, index, self.size))
self.prog.op_dep_dag.add_operation(op)
return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)
def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False):
dst = self.rank
src = other_chunkref.rank
self.prog.apply_reduce(
src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
)
if use_packet:
assert src == dst, "Packet reduce only supports intra-rank communication"
@@ -354,10 +375,16 @@ class Ref(ChunkRef):
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory):
op = Op(inst=Instruction.reduce, rank=self.rank, src=self, dst=other_chunkref)
self.prog.op_dep_dag.add_operation(op)
return self._reduce(other_chunkref, recvtb, channel_type)
# Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
def reduce_packet(self, other_chunkref, recvtb=-1):
op = Op(inst=Instruction.reduce_packet, rank=self.rank, src=self, dst=other_chunkref)
self.prog.op_dep_dag.add_operation(op)
return self._reduce(other_chunkref, recvtb, use_packet=True)
# """
@@ -366,6 +393,9 @@ class Ref(ChunkRef):
# """
# Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref
def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
op = Op(inst=Instruction.group_load_reduce, rank=self.rank, src=self, dst=None, srcs=other_chunkrefs)
self.prog.op_dep_dag.add_operation(op)
assert (
len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
), "Group load reduce only supports nvls channel"
@@ -377,21 +407,15 @@ class Ref(ChunkRef):
assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"
src_chunkref = other_chunkref
self.prog.apply_reduce(
src_chunkref.rank,
src_chunkref.buffer,
src_chunkref.index,
self.rank,
self.buffer,
self.index,
self.size,
)
self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
return self
# Copies the chunk(s) referenced by this chunkref onto other_chunkrefs
def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
extra = {"dsts": dsts, "index": index, "buffer": buffer}
op = Op(inst=Instruction.group_store, rank=self.rank, src=self, dst=None, extra=extra)
self.prog.op_dep_dag.add_operation(op)
for dst in dsts:
self.prog.check_buffer_exists(dst, buffer)
assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
@@ -434,7 +458,7 @@ def chunk(rank, buffer, index, size=1) -> Ref:
if buffer is Buffer.scratch:
if buffer not in _curr().buffers[rank]:
_curr().buffers[rank][buffer] = BufferSlice(Buffer.scratch, buffer)
if index >= len(_curr().buffers[rank][buffer]):
if index >= len(_curr().buffers[rank][buffer]) or _curr().buffers[rank][buffer][index] is None:
_curr().buffers[rank][buffer][index] = ChunkRef(rank, buffer, index, size)
if _curr().buffers[rank][buffer][index] is None:
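The program-side changes follow a record-then-replay pattern: each DSL call (put, put_packet, get, copy, copy_packet, reduce, reduce_packet, signal, wait, flush, group_load_reduce, group_store) now appends an Op to self.prog.op_dep_dag before doing its usual work, and check() replays op_dep_dag.get_execution_order() through the CollectiveChecker. A rough sketch of that pattern with hypothetical stand-ins (RecordedOp, OpRecorder, and the mini put helper are illustrative, not the mscclpp API):

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class RecordedOp:
    inst: str
    src: Tuple[int, str, int, int]   # (rank, buffer, index, size)
    dst: Tuple[int, str, int, int]

@dataclass
class OpRecorder:
    ops: List[RecordedOp] = field(default_factory=list)

    def add_operation(self, op: RecordedOp) -> None:
        self.ops.append(op)

    def get_execution_order(self) -> List[RecordedOp]:
        # Toy: insertion order; the real graph reorders across ranks
        # to respect signal/wait dependencies.
        return list(self.ops)

recorder = OpRecorder()

def put(src_rank: int, dst_rank: int, index: int, size: int = 1) -> None:
    # Record first, mirroring the self.prog.op_dep_dag.add_operation(op) calls in the diff;
    # the real DSL then goes on to build the instruction DAG as before.
    recorder.add_operation(
        RecordedOp("put", (src_rank, "input", index, size), (dst_rank, "output", index, size))
    )

put(0, 1, 0)
put(1, 0, 1)
assert [op.inst for op in recorder.get_execution_order()] == ["put", "put"]
```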


@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.types import *
from queue import Queue
from typing import Dict, List, Tuple
class OperationDependencyGraph:
"""
A DAG structure to enforce correct execution order of collective communication operations.
Supports topological sorting based on rank/threadblock execution and signal/wait synchronization.
"""
def __init__(self):
self.root_nodes: List["Node"] = []
self.previous_node: Dict[Tuple[int, int], "Node"] = {}
self.signalling: Dict[Tuple[int, int, int], Queue] = {}
self.waiting: Dict[Tuple[int, int, int], Queue] = {}
def add_operation(self, op: "Op"):
"""
Inserts an operation into the DAG, adding edges based on dependencies.
"""
node = self.Node(op)
rank = op.rank
tb = op.tb
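# A barrier is ordered after the most recent operation of every threadblock listed in its tb_list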
if op.inst == Instruction.barrier:
for tb in op.extra.get("tb_list", []):
if (rank, tb) not in self.previous_node:
self.previous_node[(rank, tb)] = node
self.root_nodes.append(node)
else:
prev_node = self.previous_node[(rank, tb)]
prev_node.next_nodes.append(node)
node.input += 1
self.previous_node[(rank, tb)] = node
else:
if (rank, tb) not in self.previous_node:
self.previous_node[(rank, tb)] = node
self.root_nodes.append(node)
else:
prev_node = self.previous_node[(rank, tb)]
prev_node.next_nodes.append(node)
node.input += 1
self.previous_node[(rank, tb)] = node
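# A signal either satisfies an already-queued wait for the same (src, dst, tb) key or is queued until the matching wait arrives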
if op.inst == Instruction.signal:
if (op.src.rank, op.dst.rank, tb) not in self.waiting or self.waiting[
(op.src.rank, op.dst.rank, tb)
].empty():
if (op.src.rank, op.dst.rank, tb) not in self.signalling:
self.signalling[(op.src.rank, op.dst.rank, tb)] = Queue()
self.signalling[(op.src.rank, op.dst.rank, tb)].put(node)
else:
waiting_node = self.waiting[(op.src.rank, op.dst.rank, tb)].get()
node.next_nodes.append(waiting_node)
waiting_node.input += 1
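# A wait either pairs with an already-queued signal (adding a cross-rank edge) or is queued until the matching signal arrives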
if op.inst == Instruction.wait:
if (op.src.rank, op.dst.rank, tb) not in self.signalling or self.signalling[
(op.src.rank, op.dst.rank, tb)
].empty():
if (op.src.rank, op.dst.rank, tb) not in self.waiting:
self.waiting[(op.src.rank, op.dst.rank, tb)] = Queue()
self.waiting[(op.src.rank, op.dst.rank, tb)].put(node)
else:
signalling_node = self.signalling[(op.src.rank, op.dst.rank, tb)].get()
signalling_node.next_nodes.append(node)
node.input += 1
def get_execution_order(self):
"""
Returns the operations in the DAG in a dependency-respecting (topological) order.
"""
order = []
queue = Queue()
for node in self.root_nodes:
queue.put(node)
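# Breadth-first pass: a successor is enqueued only once all of its incoming edges have been visited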
while not queue.empty():
node = queue.get()
op = node.operation
order.append(op)
for next_node in node.next_nodes:
next_node.reach += 1
if next_node.reach == next_node.input:
queue.put(next_node)
return order
class Node:
operation: "Op"
next_nodes: list
input: int
reach: int
def __init__(self, operation: "Op"):
self.operation = operation
self.next_nodes = []
self.input = 0
self.reach = 0
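To see how the two ordering mechanisms combine, here is a small self-contained sketch: program-order edges within one (rank, threadblock) lane plus a matched signal-to-wait edge across ranks, followed by a Kahn-style traversal in the spirit of get_execution_order(). The ToyNode class, the labels, and the indegree-zero root filter are simplifications for illustration rather than the exact behavior of the class above.

```python
from queue import Queue

class ToyNode:
    def __init__(self, label):
        self.label = label
        self.next_nodes = []
        self.input = 0    # number of dependency edges in
        self.reach = 0    # edges visited so far

def link(a, b):
    a.next_nodes.append(b)
    b.input += 1

put0    = ToyNode("rank0/tb0: put")
signal0 = ToyNode("rank0/tb0: signal -> rank1")
wait1   = ToyNode("rank1/tb0: wait <- rank0")
get1    = ToyNode("rank1/tb0: get")

link(put0, signal0)    # same (rank, tb): program order
link(signal0, wait1)   # matched signal/wait: cross-rank edge
link(wait1, get1)      # same (rank, tb): program order

order, queue = [], Queue()
for node in (put0, signal0, wait1, get1):
    if node.input == 0:          # roots: nothing to wait for
        queue.put(node)
while not queue.empty():
    node = queue.get()
    order.append(node.label)
    for nxt in node.next_nodes:
        nxt.reach += 1
        if nxt.reach == nxt.input:
            queue.put(nxt)

# rank1's wait (and everything after it) is ordered after rank0's signal.
assert order == [
    "rank0/tb0: put",
    "rank0/tb0: signal -> rank1",
    "rank1/tb0: wait <- rank0",
    "rank1/tb0: get",
]
```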