Include a static synchronization check in the DSL. (#806)

This commit is contained in:
Caio Rocha
2026-05-19 13:06:53 -07:00
committed by GitHub
parent 60a6d7219f
commit c1071318c8
2 changed files with 38 additions and 0 deletions

View File

@@ -78,6 +78,7 @@ class MemoryChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
"""Wait for a signal through the memory channel.
@@ -99,6 +100,7 @@ class MemoryChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Retrieve data from remote memory to local memory.
@@ -508,6 +510,7 @@ class PortChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def wait(self, tb: int, data_sync: SyncType = SyncType.both):
"""Wait for a signal through the port channel.
@@ -527,6 +530,7 @@ class PortChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
def flush(self, tb: int, data_sync: SyncType = SyncType.both):
"""Flush pending operations through the port channel.
@@ -636,6 +640,7 @@ class PortChannel:
)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
"""Send data from local memory to remote memory with signal and flush.
@@ -681,6 +686,7 @@ class PortChannel:
)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
"""Transfer data from local buffer to remote scratch buffer in packet format.

View File

@@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
from typing import List
from collections import defaultdict
import json
@@ -112,6 +113,9 @@ class CollectiveProgram:
self.loop_context = None
self._signal_counts = defaultdict(int)
self._wait_counts = defaultdict(int)
@classmethod
def from_spec(cls, spec: AlgoSpec):
"""Initialize a new CollectiveProgram from an algorithm specification.
@@ -206,7 +210,35 @@ class CollectiveProgram:
else:
self.gpus[rank].add_operation(tb, operation)
def register_signal(self, src_rank, dst_rank, channel_type):
    """Count one signal sent from `src_rank` toward `dst_rank` over `channel_type`.

    The tally is stored in `self._signal_counts`, keyed by the
    (src_rank, dst_rank, channel_type) triple.
    """
    pairing_key = (src_rank, dst_rank, channel_type)
    self._signal_counts[pairing_key] += 1
def register_wait(self, src_rank, dst_rank, channel_type):
    """Count one wait performed by `src_rank` for `dst_rank` over `channel_type`.

    The tally is stored in `self._wait_counts`, keyed by the
    (src_rank, dst_rank, channel_type) triple, where `src_rank` is the rank
    that executes the wait.
    """
    pairing_key = (src_rank, dst_rank, channel_type)
    self._wait_counts[pairing_key] += 1
def validate_signal_wait_pairing(self):
    """Validate that every signal issued by a rank is matched by a wait on the peer rank.

    For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by
    `src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank`
    for `src_rank` on a channel of the same type.

    Triples are checked in a deterministic order (by source rank, destination rank,
    then the string form of the channel type), so when several pairs are unbalanced
    the same one is reported on every run instead of depending on set iteration order.

    Raises:
        RuntimeError: on the first mismatching triple in that order.
    """
    # Waits are recorded as (waiting_rank, peer_rank, type); flip them so that both
    # tables are expressed from the signalling rank's point of view.
    keys = set(self._signal_counts) | {(dst, src, t) for (src, dst, t) in self._wait_counts}

    # Channel types may not be orderable among themselves, so compare their string
    # form. Ranks are assumed to be mutually comparable (plain ints) -- TODO confirm.
    ordered_keys = sorted(keys, key=lambda key: (key[0], key[1], str(key[2])))

    for src_rank, dst_rank, channel_type in ordered_keys:
        # `.get` keeps a defaultdict from growing new zero entries during validation.
        signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0)
        waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0)
        if signals != waits:
            raise RuntimeError(
                f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} "
                f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) "
                f"for rank {src_rank}. Every signal must be matched by a corresponding wait "
                f"on the peer rank over a channel of the same type."
            )
def post_process_operations(self):
self.validate_signal_wait_pairing()
for gpu in self.gpus:
if self.instr_fusion:
gpu.optimize_operations()