mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 15:24:43 +00:00
Include a static synchronization check in the DSL. (#806)
This commit is contained in:
@@ -78,6 +78,7 @@ class MemoryChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
|
||||
"""Wait for a signal through the memory channel.
|
||||
@@ -99,6 +100,7 @@ class MemoryChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
|
||||
"""Retrieve data from remote memory to local memory.
|
||||
@@ -508,6 +510,7 @@ class PortChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def wait(self, tb: int, data_sync: SyncType = SyncType.both):
|
||||
"""Wait for a signal through the port channel.
|
||||
@@ -527,6 +530,7 @@ class PortChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def flush(self, tb: int, data_sync: SyncType = SyncType.both):
|
||||
"""Flush pending operations through the port channel.
|
||||
@@ -636,6 +640,7 @@ class PortChannel:
|
||||
)
|
||||
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
|
||||
"""Send data from local memory to remote memory with signal and flush.
|
||||
@@ -681,6 +686,7 @@ class PortChannel:
|
||||
)
|
||||
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
|
||||
"""Transfer data from local buffer to remote scratch buffer in packet format.
|
||||
|
||||
@@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
|
||||
@@ -112,6 +113,9 @@ class CollectiveProgram:
|
||||
|
||||
self.loop_context = None
|
||||
|
||||
self._signal_counts = defaultdict(int)
|
||||
self._wait_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def from_spec(cls, spec: AlgoSpec):
|
||||
"""Initialize a new CollectiveProgram from an algorithm specification.
|
||||
@@ -206,7 +210,35 @@ class CollectiveProgram:
|
||||
else:
|
||||
self.gpus[rank].add_operation(tb, operation)
|
||||
|
||||
def register_signal(self, src_rank, dst_rank, channel_type):
    """Count one signal sent by `src_rank` towards `dst_rank` over `channel_type`.

    The tally is kept in `self._signal_counts`, keyed by the
    `(src_rank, dst_rank, channel_type)` triple, and is read back by
    `validate_signal_wait_pairing`.
    """
    signal_key = (src_rank, dst_rank, channel_type)
    self._signal_counts[signal_key] += 1
|
||||
|
||||
def register_wait(self, src_rank, dst_rank, channel_type):
    """Count one wait performed by `src_rank` for its peer `dst_rank` over `channel_type`.

    The tally is kept in `self._wait_counts`, keyed from the waiting rank's
    point of view as `(src_rank, dst_rank, channel_type)`, and is read back by
    `validate_signal_wait_pairing`.
    """
    wait_key = (src_rank, dst_rank, channel_type)
    self._wait_counts[wait_key] += 1
|
||||
|
||||
def validate_signal_wait_pairing(self):
    """Validate that every signal issued by a rank is matched by a wait on the peer rank.

    For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by
    `src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank`
    for `src_rank` on a channel of the same type.

    Every mismatching triple is checked and reported in a single error, one line per
    mismatch, ordered by (src_rank, dst_rank, str(channel_type)). The report is
    therefore reproducible across runs and shows all problems at once instead of
    whichever one an unordered set happened to yield first.

    Raises:
        RuntimeError: if at least one signal count differs from its matching wait count.
    """
    signal_counts = self._signal_counts
    wait_counts = self._wait_counts

    # Wait keys are recorded as (waiting_rank, peer_rank, type). Flip them so that
    # both tables are expressed in the signal direction (signaller, waiter, type).
    directed_pairs = set(signal_counts) | {
        (peer_rank, waiting_rank, channel_type)
        for (waiting_rank, peer_rank, channel_type) in wait_counts
    }

    mismatches = []
    # str() on the channel type gives a sortable key even when the type itself
    # (e.g. an enum member) defines no ordering.
    for src_rank, dst_rank, channel_type in sorted(
        directed_pairs, key=lambda pair: (pair[0], pair[1], str(pair[2]))
    ):
        # .get() rather than indexing: the tables may be defaultdicts, and indexing
        # would insert zero-count entries as a side effect of validation.
        signals = signal_counts.get((src_rank, dst_rank, channel_type), 0)
        waits = wait_counts.get((dst_rank, src_rank, channel_type), 0)
        if signals != waits:
            mismatches.append(
                f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} "
                f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) "
                f"for rank {src_rank}. Every signal must be matched by a corresponding wait "
                f"on the peer rank over a channel of the same type."
            )

    if mismatches:
        raise RuntimeError("\n".join(mismatches))
|
||||
|
||||
def post_process_operations(self):
|
||||
self.validate_signal_wait_pairing()
|
||||
for gpu in self.gpus:
|
||||
if self.instr_fusion:
|
||||
gpu.optimize_operations()
|
||||
|
||||
Reference in New Issue
Block a user