From c1071318c84f968bc292e2ef9b8296ba837d06af Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 19 May 2026 13:06:53 -0700 Subject: [PATCH] Include a static synchronization check in the DSL. (#806) --- python/mscclpp/language/channel.py | 6 ++++++ python/mscclpp/language/program.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 23d76eda..de0f65c5 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -78,6 +78,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False): """Wait for a signal through the memory channel. @@ -99,6 +100,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Retrieve data from remote memory to local memory. @@ -508,6 +510,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both): """Wait for a signal through the port channel. 
@@ -527,6 +530,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def flush(self, tb: int, data_sync: SyncType = SyncType.both): """Flush pending operations through the port channel. @@ -636,6 +640,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Send data from local memory to remote memory with signal and flush. @@ -681,6 +686,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Transfer data from local buffer to remote scratch buffer in packet format. diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index c29e9ab7..825a9d40 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore from mscclpp.language.collectives import * from mscclpp.language.utils import AlgoSpec, ReplicationPolicy from typing import List +from collections import defaultdict import json @@ -112,6 +113,9 @@ class CollectiveProgram: self.loop_context = None + self._signal_counts = defaultdict(int) + self._wait_counts = defaultdict(int) + @classmethod def from_spec(cls, spec: AlgoSpec): """Initialize a new CollectiveProgram from an algorithm specification. 
@@ -206,7 +210,35 @@ class CollectiveProgram:
         else:
             self.gpus[rank].add_operation(tb, operation)
 
+    def register_signal(self, src_rank, dst_rank, channel_type):
+        """Count one signal that `src_rank` issues to `dst_rank` over a `channel_type` channel."""
+        self._signal_counts[(src_rank, dst_rank, channel_type)] += 1
+
+    def register_wait(self, src_rank, dst_rank, channel_type):
+        """Count one wait that `src_rank` performs on its `channel_type` channel to `dst_rank`."""
+        self._wait_counts[(src_rank, dst_rank, channel_type)] += 1
+
+    def validate_signal_wait_pairing(self):
+        """Check that signals from each rank are matched one-for-one by waits on the peer rank.
+
+        For every (src_rank, dst_rank, channel_type), the signals counted for src_rank -> dst_rank
+        must equal the waits counted for dst_rank on src_rank. Triples are checked in sorted order
+        so the first mismatch reported is the same on every run. Raises RuntimeError on mismatch.
+        """
+        keys = set(self._signal_counts) | {(dst, src, t) for (src, dst, t) in self._wait_counts}
+        for src_rank, dst_rank, channel_type in sorted(keys, key=lambda k: (k[0], k[1], str(k[2]))):
+            signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0)  # .get() never inserts keys
+            waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0)
+            if signals != waits:
+                raise RuntimeError(
+                    f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} "
+                    f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) "
+                    f"for rank {src_rank}. Every signal must be matched by a corresponding wait "
+                    f"on the peer rank over a channel of the same type."
+                )
+
     def post_process_operations(self):
+        self.validate_signal_wait_pairing()
         for gpu in self.gpus:
             if self.instr_fusion:
                 gpu.optimize_operations()