mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
434 lines
18 KiB
Python
434 lines
18 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from dataclasses import dataclass
|
|
from mscclpp.language.collectives import Collective
|
|
from mscclpp.language.buffer import *
|
|
from mscclpp.language.types import ChannelType, ChunkRef, ReplicationPolicy, Threadblock
|
|
from mscclpp.language.ir import *
|
|
from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG
|
|
from mscclpp.language.rank import Rank
|
|
|
|
_current_program = None
|
|
|
|
|
|
def _curr():
|
|
global _current_program
|
|
if _current_program == None:
|
|
raise RuntimeError("No Program in context")
|
|
return _current_program
|
|
|
|
|
|
# For an MSCCL++ program, we assume a channel is identified by
# (send_buffer, recv_buffer, type, send_tb/recv_tb), which means the send_tb
# and recv_tb must match for a signal/wait pair, and likewise for put/get
# operations. If a sender wants to target a different threadblock on the
# receiver side, it must first send to the same threadblock on the receiver
# and then perform a cross-threadblock synchronization. This is a limitation
# of the current implementation.
|
|
class MSCCLPPProgram:
    """Context manager that records an MSCCL++ program for a collective.

    While the program is in context (entered via ``with``), the module-level
    helpers (``chunk``, ``rank``, ``Json``, ``Check``) operate on it.
    Operations are accumulated in an InstructionDAG and lowered to an
    execution-plan ``Program`` by :meth:`lower`.
    """

    def __init__(
        self,
        name: str,
        collective: Collective,
        num_ranks: int,
        instances: int,
        protocol: str = "Simple",
        instr_fusion: bool = True,
        replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
        num_threads_per_block: int = 1024,
        use_double_scratch_buffer: bool = False,
        min_message_size: int = 0,
        max_message_size: int = 2**64 - 1,
    ):
        """Create a program.

        Args:
            name: Program name carried into the lowered plan.
            collective: Collective definition; supplies buffers and checks.
            num_ranks: Number of participating ranks (GPUs).
            instances: Replication factor applied at lowering time.
            protocol: Either "Simple" or "LL".
            instr_fusion: Whether to run instruction-fusion optimization.
            replication_policy: How instances are replicated when lowering.
            num_threads_per_block: Threadblock size recorded in the plan.
            use_double_scratch_buffer: Whether the plan double-buffers scratch.
            min_message_size: Lower bound of message sizes this plan serves.
            max_message_size: Upper bound of message sizes this plan serves.
        """
        self.name = name
        self.collective = collective
        self.num_ranks = num_ranks
        self.instances = instances
        self.protocol = protocol
        self.instr_fusion = instr_fusion
        self.replication_policy = replication_policy
        self.num_threads_per_block = num_threads_per_block
        self.use_double_scratch_buffer = use_double_scratch_buffer
        self.min_message_size = min_message_size
        self.max_message_size = max_message_size
        assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL"
        self.run_opt = True  # Runs optimization passes
        # Initialize the input buffers
        self.buffers = collective.init_buffers()
        self.instr_dag = InstructionDAG(self.num_ranks, self.buffers)
        self.ranks = []
        for r in range(self.num_ranks):
            self.ranks.append(Rank(r))
            # Seed the DAG with a start node for every input chunk.
            for index, chunk in enumerate(self.buffers[r][Buffer.input]):
                buffer, index = self.collective.get_buffer_index(r, Buffer.input, index)
                ref = self.get_ref(r, buffer, index, 1)
                # self.chunk_dag.init_chunk(chunk, ref)
                self.instr_dag.add_start(r, buffer, index, ref)

    def __enter__(self):
        """Enter the program context; only one program may be active at a time."""
        global _current_program
        # Identity comparison against the None singleton (PEP 8).
        if _current_program is not None:
            raise RuntimeError("There is already a MSCCLPP Program in context")
        _current_program = self
        # Return the program so `with MSCCLPPProgram(...) as prog:` binds it
        # (previously returned None, so this is backward-compatible).
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Leave the program context, clearing the module-level slot."""
        global _current_program
        # Identity check: this object has no custom __eq__, identity is intended.
        if _current_program is not self:
            raise RuntimeError("This program is not currently in context")
        _current_program = None

    def _convert_to_execution_plan(self):
        """Group the DAG's operations into per-rank threadblocks, ordered by step."""
        ops = self.instr_dag.convert_set_list()
        ops = sorted(ops, key=lambda x: x.step)
        for op in ops:
            rank = op.rank
            tbid = op.tb
            if tbid not in self.instr_dag.tbs[rank]:
                self.instr_dag.tbs[rank][tbid] = Threadblock(id=tbid)
            tb = self.instr_dag.tbs[rank][tbid]
            tb.ops.append(op)

    def get_rank_ref(self, rank):
        """Return a RankRef handle for rank-scoped operations (e.g. barriers)."""
        return RankRef(rank, self)

    # Tracks a send operation on the buffers
    def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        """Model a send by copying `size` chunks from src to dst buffers."""
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            db[dst_index + i] = sb[src_index + i]

    # Tracks a reduce operation on the buffers
    def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        """Model a reduce by folding `size` src chunks into the dst chunks."""
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            reduce_chunk = db[dst_index + i]
            sent_chunk = sb[src_index + i]
            db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)

    def get_ref(self, rank, buffer, index, size):
        """Return a Ref to `size` chunks at (rank, buffer, index)."""
        buffer, index = self.collective.get_buffer_index(rank, buffer, index)
        return Ref(rank, buffer, index, size, self)

    def get_chunks(self, rank, buffer, index, size=1):
        """Return the `size` chunk objects starting at index; None for out-of-range slots."""
        chunks = [None] * size
        for i in range(0, size):
            if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]):
                chunks[i] = self.buffers[rank][buffer][index + i]
            else:
                chunks[i] = None
        return chunks

    def check_buffer_exists(self, rank, name):
        """Lazily create a scratch BufferSlice named `name` on `rank` if absent."""
        if name not in self.buffers[rank]:
            self.buffers[rank][name] = BufferSlice(Buffer.scratch, name)

    # Checks that all chunks that should be on each rank
    # are present in the output buffer.
    def check(self):
        return self.collective.check(self)

    # Lower program to MSCCLPP
    def lower(self):
        """Lower the recorded DAG into an execution-plan Program."""
        self._convert_to_execution_plan()
        self.instr_dag.complete_channels()
        dag_optimizer = DagOptimizer(self.instr_dag)
        dag_optimizer.remove_redundant_signal_wait()
        if self.instr_fusion:
            dag_optimizer.fuse_instructions()
        dag_lower = DagLower(self.instr_dag)
        gpu_prgms = dag_lower.lower(self.instances, self.replication_policy)
        program = Program(
            self.name,
            self.collective.name,
            self.collective.inplace,
            self.protocol,
            gpu_prgms,
            self.collective.num_chunk_groups * self.instances,
            self.num_threads_per_block,
            self.use_double_scratch_buffer,
            self.min_message_size,
            self.max_message_size,
        )
        # Chunk counts scale with the number of instances.
        for gpu in program.gpus:
            gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances
            if not self.collective.inplace:
                gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances
        return program

    def generate_json(self):
        """Lower the program and serialize the result to JSON."""
        return ir_to_json(self.lower())
|
|
|
|
|
|
def Json():
    """Print the JSON serialization of the program currently in context."""
    serialized = _curr().generate_json()
    print(serialized)
|
|
|
|
|
|
@dataclass
class RankRef:
    """Per-rank handle used for rank-scoped operations such as barriers."""

    rank: int
    prog: MSCCLPPProgram

    def _get_barrier_id(self, tb_list) -> int:
        # Barrier-id allocation is delegated to the Rank bookkeeping object.
        rank_state = self.prog.ranks[self.rank]
        return rank_state.get_barrier_id(tb_list)

    def barrier(self, tb_list):
        """Record a barrier across the given threadblocks on this rank."""
        bid = self._get_barrier_id(tb_list)
        return self.prog.instr_dag.add_barrier(self.rank, tb_list, bid)
|
|
|
|
|
|
@dataclass
class Ref(ChunkRef):
    """A reference to `size` chunks at (rank, buffer, index) in a program.

    Communication primitives (put/get/signal/wait/copy/reduce/group ops) are
    expressed as methods on chunk references; each records nodes into the
    program's instruction DAG and, where applicable, models the data movement
    on the program's buffers.
    """

    prog: MSCCLPPProgram

    def __repr__(self):
        return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})"

    def _end(self):
        # One past the last chunk index covered by this reference.
        return self.index + self.size

    def _get_chunk(self, index):
        return self.prog.buffers[self.rank][self.buffer][index]

    def split(self, num):
        """Split this reference into `num` equally sized sub-references."""
        assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts"
        chunks = [None] * num
        size = self.size // num
        for i in range(num):
            index = self.index + i * size
            chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size)
        return chunks

    def group(self, other):
        """Return a reference spanning both this reference and `other`.

        Both must be on the same rank and buffer; the result covers from the
        lower start index to the higher end index.
        """
        assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}"
        assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}"
        if self.index < other.index:
            first = self
            second = other
        else:
            first = other
            second = self

        end = max(first._end(), second._end())
        return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog)

    def _get_buffer_index(self, remote_rank, buffer, index):
        """Resolve defaulted (buffer, index) arguments.

        Neither given: mirror this reference's own placement. A non-input,
        non-output (i.e. scratch) buffer given without an index: append at
        the end of the remote buffer. Otherwise use the caller's values.
        """
        # Identity comparison against the None singleton (PEP 8), not `==`.
        if index == -1 and buffer is None:
            return self.buffer, self.index
        elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
            return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
        return buffer, index

    def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory, use_packet=False):
        """Record a put of this chunk to rank `dst`; returns the destination ref."""
        self.prog.check_buffer_exists(dst, buffer)
        assert self.rank != dst, "Cannot put to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
        if use_packet:
            # Packet puts also record signal/wait edges with tb -1 and no
            # channel — dependency-only nodes in the DAG.
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
            self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
            self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
        else:
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
        return dst_chunkref

    def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
        """Put this chunk to rank `dst` (no packet transformation)."""
        return self._put(dst, buffer, index, sendtb, chan_type)

    def put_packet(
        self,
        dst,
        buffer=None,
        index=-1,
        sendtb=-1,
        chan_type=ChannelType.memory,
        temp_buffer=None,
        temp_buffer_index=-1,
    ):
        """Put this chunk to rank `dst` in packet format.

        Port channels require staging: the data is first copied locally into
        `temp_buffer` (converted to packet format), then put from there.
        """
        chunk_ref = self
        if chan_type == ChannelType.port:
            assert temp_buffer is not None, "Need to specify a temporary buffer for port channels"
            chunk_ref = self._copy(
                self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
            )
        return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)

    def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
        """Record a get of the chunk at (src, buffer, index) into this chunk."""
        self.prog.check_buffer_exists(src, buffer)
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot get from the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)

        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)

        self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size)
        self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type)

    # For signal and wait, we currently assume the pair uses the same tb index.
    # In the future we need to infer the tb index from the instruction DAG: a
    # channel is defined as (send_tb, src_buffer, recv_tb, dst_buffer, type),
    # so DAG info could be used to reduce the number of channels.
    def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.memory):
        """Record a signal from this rank to rank `dst`."""
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot signal to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)

    # only port channel need to use this function
    def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.port):
        """Record a flush toward rank `dst`; valid only on port channels."""
        assert chan_type == ChannelType.port, "Only port channel can use flush"
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot flush to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)

    def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.memory):
        """Record a wait on this rank for a signal from rank `src`."""
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot wait on the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)

        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
        self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type)

    def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False):
        """Record an intra-rank copy; returns the destination ref (None if a no-op self-copy)."""
        self.prog.check_buffer_exists(dst, buffer)
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        # Check if we are copying the chunk to the same index (easy mistake when we are using inplace)
        if dst_chunkref == self:
            return
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)

        assert self.rank == dst, "Chunk copy only supports intra-rank communication"
        self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet)

        return dst_chunkref

    # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
    def copy(self, dst, buffer=None, index=-1, sendtb=-1):
        return self._copy(dst, buffer, index, sendtb)

    def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
        """Copy this chunk, converting from packet format to plain data."""
        return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

    def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory, use_packet=False):
        """Record a reduce of `other_chunkref` into this chunk; returns self."""
        dst = self.rank
        src = other_chunkref.rank

        self.prog.apply_reduce(
            src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
        )
        if use_packet:
            assert src == dst, "Packet reduce only supports intra-rank communication"

        # Cross-rank reduces read the remote chunk; local ones reduce in place.
        if src != dst:
            self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type)
        else:
            self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet)

        return self

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.memory):
        return self._reduce(other_chunkref, recvtb, channel_type)

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce_packet(self, other_chunkref, recvtb=-1):
        return self._reduce(other_chunkref, recvtb, use_packet=True)

    # """
    # Group operations. These operations are used to perform collective operations across multiple chunks.
    # For now, all chunks must have the same buffer type and offset.
    # """
    # Reads the chunk(s) referenced by other_chunkref and reduce into the chunk referenced by this chunkref
    def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
        """Record an NVLS group load-reduce of `other_chunkrefs` into this chunk."""
        assert (
            len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
        ), "Group load reduce only supports nvls channel"
        nranks_per_node = self.prog.collective.num_ranks_per_node
        for other_chunkref in other_chunkrefs:
            assert (
                self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node
            ), "Group load reduce only supports chunks on the same node"
            assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
            assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"

            src_chunkref = other_chunkref
            self.prog.apply_reduce(
                src_chunkref.rank,
                src_chunkref.buffer,
                src_chunkref.index,
                self.rank,
                self.buffer,
                self.index,
                self.size,
            )
        self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
        return self

    # Copies the chunk(s) referenced by this chunkref onto other_chunkrefs
    def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
        """Record an NVLS group store of this chunk onto each rank in `dsts`."""
        for dst in dsts:
            self.prog.check_buffer_exists(dst, buffer)
        assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
        assert chan_type == ChannelType.nvls, "Group store only supports nvls channel"

        other_chunkrefs = []
        nrank_per_node = self.prog.collective.num_ranks_per_node
        for dst in dsts:
            # Direct linked
            buffer, index = self._get_buffer_index(dst, buffer, index)

            assert self.buffer == buffer, "Group store only supports chunks with the same buffer"
            assert (
                self.rank // nrank_per_node == dst // nrank_per_node
            ), "Group store only supports chunks on the same node"

            dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
            self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
            other_chunkrefs.append(dst_chunkref)
        # add new op here
        self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type)

    def get_origin_index(self, index=0):
        """Return the origin index of the chunk at offset `index` within this ref."""
        return self._get_chunk(index + self.index).origin_index

    def get_origin_rank(self, index=0):
        """Return the origin rank of the chunk at offset `index` within this ref."""
        return self._get_chunk(index + self.index).origin_rank

    def get_dst_index(self, index=0):
        """Return the destination index of the chunk at offset `index` within this ref."""
        return self._get_chunk(index + self.index).dst_index

    def get_dst_rank(self, index=0):
        """Return the destination rank of the chunk at offset `index` within this ref."""
        return self._get_chunk(index + self.index).dst_rank

    def print_chunk_info(self, index=0):
        """Print the chunk object at offset `index` within this ref (debug aid)."""
        print(self._get_chunk(index + self.index))
|
|
|
|
|
|
def chunk(rank, buffer, index, size=1) -> Ref:
    """Return a Ref to `size` chunks at (rank, buffer, index) in the current
    program, or None if that slot holds no chunk."""
    program = _curr()
    if program.buffers[rank][buffer][index] is None:
        return None
    return program.get_ref(rank, buffer, index, size)
|
|
|
|
|
|
def rank(rank) -> RankRef:
    """Return a RankRef for `rank` in the current program."""
    program = _curr()
    return program.get_rank_ref(rank)
|
|
|
|
|
|
def Check():
    """Run the collective's correctness check on the current program."""
    program = _curr()
    return program.check()
|