Merge mscclpp-lang to mscclpp project (#442)

First step to merge msccl-tools into the mscclpp repo. In this step we will
move all msccl-related code, pass the current tests, and do some
necessary refactoring.

Add `mscclpp.language` module
Add `_InstructionOptimizer` and `DagOptimizer` class to optimize the dag
Add `DagLower` to lower dag to intermediate representation 
Add documents for mscclpp.language
Remove msccl related code
This commit is contained in:
Binyang Li
2025-01-22 09:47:37 -08:00
committed by GitHub
parent 4ee15b7ad0
commit af0bb86e07
28 changed files with 3417 additions and 18 deletions

View File

@@ -0,0 +1,55 @@
import argparse
from mscclpp.language import *
from mscclpp.language.buffer import Buffer
from mscclpp.language.collectives import AllGather
from mscclpp.language.types import ChannelType, ReplicationPolicy
def allgather_test(gpus, instances):
    """
    Demonstrates how to use barrier in the MSCCL++ DSL with an allgather collective.

    This example uses an allpairs algorithm for the allgather operation.
    Steps:
    1. Each rank sends a chunk to all other ranks' output buffers and copies the chunk to its own output buffer.
    2. A barrier is called to synchronize the send and copy operations, and signal peers that the data has been sent.
    3. Wait for all the chunks from other ranks to be received.
    """
    size = gpus
    collective = AllGather(size, 1, False)
    with MSCCLPPProgram(
        "allgather_with_barrier",
        collective,
        size,
        instances,
        protocol="Simple",
        replication_policy=ReplicationPolicy.interleaved,
    ):
        for src in range(gpus):
            src_chunk = chunk(src, Buffer.input, 0, 1)
            for dst in range(gpus):
                if src == dst:
                    src_chunk.copy(src, Buffer.output, src, sendtb=dst)
                else:
                    src_chunk.put(dst, Buffer.output, src, sendtb=dst, chan_type=ChannelType.sm)
            # Explicit barrier across all threadblocks before signaling peers.
            rank(src).barrier(tb_list=list(range(gpus)))
            for dst in range(gpus):
                if src != dst:
                    src_chunk.signal(dst, Buffer.output, src, sendtb=dst, chan_type=ChannelType.sm)
        # Each rank waits until every remote chunk has arrived in its output buffer.
        for dst in range(gpus):
            for src in range(gpus):
                out_chunk = chunk(dst, Buffer.output, src, 1)
                if dst != src:
                    out_chunk.wait(src, Buffer.input, src, recvtb=src, chan_type=ChannelType.sm)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allgather_test(args.num_gpus, args.instances)

View File

@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_allpairs(gpus, instances, protocol):
    """
    Demonstrate allreduce with all pairs algorithm using put semantics.

    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Put the reduced data to all peers.
    4. Sync all ranks to ensure the data is received.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram("allreduce_pairs", collective, size, instances, protocol=protocol):
        for r in range(size):
            for tb in range(size):
                base = r * size
                own = chunk(r, Buffer.input, base + tb)
                # Step 1: signal every peer that our buffer is ready, then wait
                # for the matching signal from each peer.
                for peer in range(size):
                    peer_base = peer * size
                    if r != peer:
                        peer_chunk = chunk(r, Buffer.input, peer_base + tb)
                        peer_chunk.signal(peer, Buffer.input, peer_base + tb, sendtb=tb)
                for peer in range(size):
                    if r != peer:
                        own.wait(peer, Buffer.input, base + tb, recvtb=tb)
                # Step 2: reduce the peers' chunks into ours, then put the result back.
                for peer in range(size):
                    if r != peer:
                        own.reduce(chunk(peer, Buffer.input, base + tb), recvtb=tb)
                for peer in range(size):
                    if r != peer:
                        own.put(peer, Buffer.input, base + tb, sendtb=tb)
                # Step 3: signal that the reduced data has landed, then wait for peers.
                for peer in range(size):
                    if r != peer:
                        own.signal(peer, Buffer.input, base + tb, sendtb=tb)
                for peer in range(size):
                    if r != peer:
                        peer_base = peer * size
                        peer_chunk = chunk(r, Buffer.input, peer_base + tb)
                        peer_chunk.wait(peer, Buffer.input, peer_base + tb, recvtb=tb)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol")
args = parser.parse_args()
allreduce_allpairs(args.num_gpus, args.instances, args.protocol)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_allpairs(gpus, instances):
    """
    AllReduce with all pairs algorithm using get semantics.

    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Signal all ranks to notify that the data is ready.
    4. Wait for all chunks to be ready, then retrieve the chunks from all peers.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_pairs",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Each rank reduces its own chunk with data read from all peers.
        for r in range(size):
            for tb in range(size):
                base = r * size
                own = chunk(r, Buffer.input, base + tb)
                # Make sure the data is ready on every peer.
                for peer in range(size):
                    peer_base = peer * size
                    if r != peer:
                        peer_chunk = chunk(r, Buffer.input, peer_base + tb)
                        peer_chunk.signal(peer, Buffer.input, peer_base + tb, sendtb=tb)
                for peer in range(size):
                    if r != peer:
                        own.wait(peer, Buffer.input, base + tb, recvtb=tb)
                # Reduce the chunks, rotating the peer order per rank.
                for offset in range(size):
                    peer = (r + offset) % size
                    if r != peer:
                        own.reduce(chunk(peer, Buffer.input, base + tb), recvtb=tb)
                for peer in range(size):
                    if r != peer:
                        own.signal(peer, Buffer.input, base + tb, sendtb=tb)
        # Wait for every peer's reduced chunk to be ready, then get the chunks.
        for r in range(size):
            for tb in range(size):
                for peer in range(size):
                    if r != peer:
                        peer_base = peer * size
                        peer_chunk = chunk(r, Buffer.input, peer_base + tb)
                        peer_chunk.wait(peer, Buffer.input, peer_base + tb, recvtb=tb)
                for offset in range(size):
                    peer = (r + offset) % size
                    peer_base = peer * size
                    if r != peer:
                        peer_chunk = chunk(r, Buffer.input, peer_base + tb)
                        peer_chunk.get(peer, Buffer.input, peer_base + tb, recvtb=tb)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_allpairs(args.num_gpus, args.instances)

View File

@@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_allpairs(gpus, instances):
    """
    AllReduce with all pairs algorithm using packets format.

    Steps:
    1. Each rank sends its nth chunk to the nth rank's scratch space.
    2. Each rank performs a local reduction on its nth chunk using data from all other ranks' scratch spaces.
    3. Each rank sends the reduced data to all other ranks' scratch spaces.
    4. Each rank retrieves the final reduced result from the scratch space.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_packets",
        collective,
        size,
        instances,
        protocol="LL",
        use_double_scratch_buffer=True,
    ):
        # Step 1: scatter — rank `src` puts the chunk block belonging to remote
        # rank `tb` into that rank's scratch space (one threadblock per peer).
        for src in range(size):
            for tb in range(size):
                if tb == src:
                    continue
                dst = tb
                block = chunk(src, Buffer.input, dst * size, size)
                block.put_packet(dst, "scratch", index=src * size, sendtb=tb)
        # Step 2: each rank locally reduces its own chunk block, using one
        # threadblock per chunk index for better parallelism, then sends the
        # reduced data back to every peer's scratch space.
        for r in range(size):
            for index in range(size):
                acc = chunk(r, Buffer.input, r * size + index)
                for peer in range(size):
                    if peer != r:
                        acc.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
                for peer in range(size):
                    if peer != r:
                        acc.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)
        # Step 3: copy the final reduced result out of the scratch space.
        for r in range(size):
            for peer in range(size):
                if peer != r:
                    result = chunk(r, "scratch", size * size + peer * size, size)
                    result.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_allpairs(args.num_gpus, args.instances)

View File

@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_nvls(gpus, instances):
    """
    Allreduce via NVLS channel.

    Steps:
    1. Sync all the ranks to make sure the data is ready.
    2. Call group_load_reduce to reduce the data.
    3. Call group_store to propagate the data to all the ranks.
    """
    size = gpus
    chunksperloop = gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_nvls",
        collective,
        size,
        instances,
    ):
        for r in range(size):
            own = chunk(r, Buffer.input, r)
            peer_chunks = []
            # Signal every peer that our chunk is ready while collecting the
            # peer chunks that will take part in the group reduction.
            for peer in range(size):
                if r != peer:
                    peer_chunks.append(chunk(peer, Buffer.input, r))
                    own.signal(peer, Buffer.input, r, sendtb=0)
            for peer in range(size):
                if r != peer:
                    own.wait(peer, Buffer.input, r, recvtb=0)
            # Reduce across all peers, then store the result back to everyone.
            reduced = own.group_load_reduce(peer_chunks, recvtb=0)
            peers = [peer for peer in range(size) if peer != r]
            reduced.group_store(peers, sendtb=0)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_nvls(args.num_gpus, args.instances)

View File

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_ring(size, instances):
    """
    Implements a ring based allreduce.

    Steps:
    1. Send signal to next rank and wait for signal from previous rank. Make sure the data is ready in previous rank.
    2. Reduce the data and send to next rank.
    3. After all the data is reduced, propagate the data to all the ranks.
    """
    collective = AllReduce(size, size, True)
    with MSCCLPPProgram(
        # Fixed: was `f"allreduce_ring"`, an f-string with no placeholders.
        "allreduce_ring",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Reduce ring: on each step, every rank signals the next rank and
        # reduces the chunk arriving from the previous rank.
        for step in range(0, size - 1):
            for index in range(0, size):
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                c = chunk(rank, Buffer.input, index)
                c.signal(next_rank, Buffer.input, index, 0)
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, (index + size - 1) % size)
                c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0)
                c.reduce(chunk(prev_rank, Buffer.input, (index + size - 1) % size), recvtb=0)
        # Propagate ring: circulate the fully reduced chunks to all ranks.
        for step in range(-1, size - 2):
            for index in range(0, size):
                rank = (index + step) % size
                c = chunk(rank, Buffer.input, index)
                next_rank = (index + step + 1) % size
                c.put(next_rank, Buffer.input, index, sendtb=0)
                c.signal(next_rank, Buffer.input, index, 0)
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, (index + size - 1) % size)
                c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_ring(args.num_gpus, args.instances)

View File

@@ -0,0 +1,57 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import SendRecv
from mscclpp.language.buffer import Buffer
from mscclpp.language.types import ChannelType
def send_recv(instances):
    """
    Send and receive data between two ranks using proxy channels, with LL protocol and double scratch buffer.

    Steps:
    1. Each rank sends a chunk to every other rank's scratch buffer with packet format via proxy channel.
    2. Wait for the data to be received, then copy it to the output buffer.
    """
    size = 2
    chunksperloop = 1
    collective = SendRecv(size, chunksperloop, False)
    with MSCCLPPProgram(
        "send_recv",
        collective,
        size,
        instances,
        protocol="LL",
        use_double_scratch_buffer=True,
    ):
        for src in range(size):
            for dst in range(size):
                if dst == src:
                    continue
                data = chunk(src, Buffer.input, 0)
                data.put_packet(
                    dst,
                    "scratch",
                    1,
                    sendtb=0,
                    chan_type=ChannelType.proxy,
                    temp_buffer="scratch",
                    temp_buffer_index=0,
                )
        # Copy the received packet out of the scratch buffer on each rank.
        for r in range(size):
            received = chunk(r, "scratch", 1)
            received.copy_packet(r, Buffer.output, 0, sendtb=0)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
send_recv(args.instances)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.buffer import Buffer
from mscclpp.language.collectives import SendRecv
from mscclpp.language.types import ChannelType
def send_recv(instances):
    """
    Send and receive data between two ranks using proxy channels.

    Steps:
    1. Each rank sends a chunk to the other rank's scratch buffer and signals the other rank that the data has been sent.
    2. Wait for the data to be received then copy it to the output buffer.
    """
    size = 2
    chunksperloop = 1
    collective = SendRecv(size, chunksperloop, False)
    with MSCCLPPProgram(
        "send_recv",
        collective,
        size,
        instances,
    ):
        for src in range(size):
            for dst in range(size):
                if dst == src:
                    continue
                data = chunk(src, Buffer.input, 0)
                # Put into the peer's scratch slot, then signal and flush the
                # proxy channel so the peer can observe the data.
                data.put(dst, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
                data.signal(dst, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
                data.flush(dst, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
        for r in range(size):
            received = chunk(r, "scratch", 1)
            received.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy)
            received.copy(r, Buffer.output, 0, sendtb=0)
        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
send_recv(args.instances)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.program import MSCCLPPProgram, Json, Check, chunk, rank

View File

@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from enum import Enum
# Scratch buffer slice with manual indexing
class BufferSlice:
    """A named slice of the global scratch buffer with manual chunk indexing."""

    def __init__(self, buf, name):
        self.name = name
        self.buf = buf
        # Offset into the global scratch buffer; -1 until set_offset is called.
        self.offset = -1
        self.chunks = []

    def get_global_index(self, index):
        """Return the global scratch-buffer index for a local index."""
        assert self.offset > -1, "set_offset needs to be called first"
        return self.offset + index

    def get_buffer(self):
        return self.buf

    def instance_size(self):
        return len(self.chunks)

    def set_offset(self, offset):
        self.offset = offset

    def __getitem__(self, index):
        return self.chunks[index]

    def __setitem__(self, index, value):
        # Grow the backing list with None padding so `index` is addressable.
        if index >= len(self.chunks):
            self.chunks.extend([None] * (index + 1 - len(self.chunks)))
        self.chunks[index] = value

    def __len__(self):
        return len(self.chunks)
class Buffer(Enum):
    """Logical buffer kinds addressed by the DSL, ordered by their string value."""

    input = "i"
    output = "o"
    scratch = "s"

    def __str__(self):
        return self.value

    def __lt__(self, other):
        return self.value < other.value

    def __gt__(self, other):
        # Bug fix: previously compared with `<`, making `>` behave like `<`.
        return self.value > other.value

View File

@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
@dataclass
class Chunk:
    """A single data chunk tracked from its origin to its destination."""

    origin_rank: int  # Rank the chunk initially started at
    origin_index: int  # Index the chunk initially started at
    dst_rank: int = -1
    dst_index: int = -1

    def reduce(self, dst, chunk):
        """Combine this chunk with `chunk`, producing a ReduceChunk created at `dst`."""
        if type(chunk) is ReduceChunk:
            return chunk.reduce(dst, self)
        if type(chunk) is Chunk:
            return ReduceChunk(dst, [self, chunk])
        raise ValueError("Trying to reduce with chunk of None")

    # Identity is determined by the origin alone; destination fields are ignored.
    def __hash__(self):
        return hash((self.origin_rank, self.origin_index))

    def __eq__(self, other):
        return (
            type(other) is Chunk and self.origin_rank == other.origin_rank and self.origin_index == other.origin_index
        )

    def __lt__(self, other):
        return (self.origin_rank, self.origin_index) < (other.origin_rank, other.origin_index)


@dataclass
class ReduceChunk:
    """The result of reducing an (order-insensitive) collection of chunks."""

    # Rank the ReduceChunk is created on. Necessary since the same ReduceChunk
    # can be created on multiple ranks independently.
    creation_rank: int
    chunks: list  # List of chunks reduced

    def reduce(self, dst, chunk):
        """Fold `chunk` (Chunk or ReduceChunk) into this reduction."""
        if type(chunk) is ReduceChunk:
            combined = self.chunks + chunk.chunks
        elif type(chunk) is Chunk:
            combined = self.chunks + [chunk]
        else:
            raise ValueError("Trying to reduce with chunk of None")
        return ReduceChunk(self.creation_rank, combined)

    def sort(self):
        self.chunks.sort()

    def __hash__(self):
        self.sort()
        return hash((self.creation_rank,) + tuple(self.chunks))

    # Two reduce chunks are equal if they contain the same list of
    # chunks being reduced
    def __eq__(self, other):
        self.sort()
        other.sort()
        return self.chunks == other.chunks

View File

@@ -0,0 +1,339 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.buffer import Buffer
from mscclpp.language.chunk import Chunk, ReduceChunk
class Collective:
    """Base class describing a collective's buffer layout and correctness check."""

    def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
        self.num_ranks = num_ranks
        self.chunk_factor = chunk_factor
        self.inplace = inplace
        self.name = "custom"
        # -1 means every rank lives on the same node.
        self.num_ranks_per_node = num_ranks if num_ranks_per_node == -1 else num_ranks_per_node
        # Number of chunk groups: n chunks are grouped into m groups. The group
        # sizes are guaranteed equal, but chunk sizes within a group may differ
        # when the group size does not divide the number of chunks.
        self.num_chunk_groups = kwargs.get("num_chunk_groups", 1)

    def init_buffers(self):
        """Subclasses return the initial per-rank buffer contents."""
        pass

    def check(self, prog):
        """Subclasses verify the final buffer contents of `prog`."""
        pass

    def get_buffer_index(self, rank, buffer, index):
        """Map a logical (buffer, index) to its physical location; identity by default."""
        return buffer, index
class AllToAll(Collective):
    """All-to-all: chunk block i of rank r ends up on rank i at block r."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        super().__init__(num_ranks, chunk_factor, inplace)
        self.name = "alltoall"

    def init_buffers(self):
        """Each rank starts with one chunk destined for every (rank, slot) pair."""
        chunks_per_node = self.num_ranks * self.chunk_factor
        rank_buffers = []
        for r in range(self.num_ranks):
            input_buffer = [
                Chunk(r, index, index // self.chunk_factor, index % self.chunk_factor + r * self.chunk_factor)
                for index in range(chunks_per_node)
            ]
            # Inplace alltoall aliases the output onto the input buffer.
            output_buffer = input_buffer if self.inplace else [None] * chunks_per_node
            rank_buffers.append({Buffer.input: input_buffer, Buffer.output: output_buffer})
        return rank_buffers

    def check(self, prog):
        """Verify the expected output buffer for alltoall."""
        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][Buffer.output]
            for i in range(self.num_ranks):
                for ch in range(self.chunk_factor):
                    index = ch + i * self.chunk_factor
                    chunk = output[index]
                    expected_origin_index = ch + r * self.chunk_factor
                    if chunk is None or chunk.origin_rank != i or chunk.origin_index != expected_origin_index:
                        print(
                            f"Rank {r} chunk {index} is incorrect should be chunk({i},{expected_origin_index}) given {chunk}"
                        )
                        correct = False
        return correct
class AllGather(Collective):
    """All-gather: every rank ends with all ranks' chunks in its output buffer."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        super().__init__(num_ranks, chunk_factor, inplace)
        self.name = "allgather"

    def init_buffers(self):
        """Initialize buffers; inplace all-gather only uses the output buffer."""
        rank_buffers = []
        if self.inplace:
            for r in range(self.num_ranks):
                output_buffer = [None] * (self.num_ranks * self.chunk_factor)
                for rank in range(self.num_ranks):
                    for ch in range(self.chunk_factor):
                        output_buffer[rank * self.chunk_factor + ch] = Chunk(
                            rank, ch, -1, rank * self.chunk_factor + ch
                        )
                rank_buffers.append(
                    {
                        Buffer.input: output_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
                        Buffer.output: output_buffer,
                    }
                )
        else:
            for r in range(self.num_ranks):
                input_buffer = [Chunk(r, ch, -1, r * self.chunk_factor + ch) for ch in range(self.chunk_factor)]
                output_buffer = [None] * (self.num_ranks * self.chunk_factor)
                rank_buffers.append({Buffer.input: input_buffer, Buffer.output: output_buffer})
        return rank_buffers

    def check(self, prog):
        """Verify chunk ch of rank i sits at output index i*chunk_factor+ch on every rank."""
        correct = True
        buf = Buffer.output
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for i in range(self.num_ranks):
                for ch in range(self.chunk_factor):
                    index = i * self.chunk_factor + ch
                    chunk = output[index]
                    if chunk is None:
                        print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None")
                        correct = False
                    elif chunk.origin_rank != i or chunk.origin_index != ch:
                        print(
                            f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
                        )
                        correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # For inplace AllGathers, the input buffer points into the output buffer.
        if self.inplace and buffer == Buffer.input:
            return Buffer.output, index + rank * self.chunk_factor
        return buffer, index
class AllReduce(Collective):
    """All-reduce: every rank ends with the reduction of all ranks' inputs."""

    def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
        # Default to one chunk group per rank unless the caller overrides it.
        groups = kwargs.get("num_chunk_groups", num_ranks)
        super().__init__(num_ranks, chunk_factor, inplace, num_ranks_per_node, num_chunk_groups=groups)
        self.name = "allreduce"

    def init_buffers(self):
        """Every rank starts with chunk_factor chunks of its own data."""
        chunks_per_node = self.chunk_factor
        rank_buffers = []
        for r in range(self.num_ranks):
            # Chunks start at rank r index c and end on all ranks (-1) at index c.
            input_buffer = [Chunk(r, c, -1, c) for c in range(chunks_per_node)]
            if self.inplace:
                # Input and output buffer are the same object.
                buffers = {Buffer.input: input_buffer, Buffer.output: input_buffer}
            else:
                buffers = {Buffer.input: input_buffer, Buffer.output: [None] * chunks_per_node}
            rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        """Verify each slot holds the ReduceChunk of that index across all ranks."""
        chunks_per_node = self.chunk_factor
        buf = Buffer.input if self.inplace else Buffer.output
        expected_chunks = []
        for c in range(chunks_per_node):
            acc = ReduceChunk(-1, [])
            for r in range(self.num_ranks):
                acc = acc.reduce(-1, Chunk(r, c))
            expected_chunks.append(acc)
        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for c in range(chunks_per_node):
                chunk = output[c]
                if chunk is None or chunk != expected_chunks[c]:
                    print(
                        f"Rank {r} chunk {c} is incorrect should be ReduceChunk index {c} from all ranks, given {chunk}"
                    )
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # Inplace: the output buffer aliases the input buffer.
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index
        return buffer, index
class ReduceScatter(Collective):
    """Reduce-scatter: each rank ends with one reduced slice of the data."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "reducescatter"

    def init_buffers(self):
        """Initialize per-rank buffers.

        Every rank starts with num_ranks * chunk_factor chunks of its own data;
        chunk i*chunk_factor+c is destined for rank i at index c. The inplace
        variant only uses the input buffer.
        """
        rank_buffers = []
        for r in range(self.num_ranks):
            if self.inplace:
                input_buffer = []
                for i in range(self.num_ranks):
                    for c in range(self.chunk_factor):
                        input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
                buffers = {Buffer.input: input_buffer}
                rank_buffers.append(buffers)
            else:
                input_buffer = []
                output_buffer = [None] * self.chunk_factor
                for i in range(self.num_ranks):
                    for c in range(self.chunk_factor):
                        input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
                rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        """Verify rank r holds the reduction of index r*chunk_factor+c across all ranks."""
        expected_chunks = []
        buf = Buffer.input if self.inplace else Buffer.output
        for c in range(self.num_ranks * self.chunk_factor):
            chunk = ReduceChunk(-1, [])
            for r in range(self.num_ranks):
                chunk = chunk.reduce(-1, Chunk(r, c))
            expected_chunks.append(chunk)
        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for c in range(self.chunk_factor):
                correct_idx = r * self.chunk_factor + c
                if self.inplace:
                    # Inplace results live at their global index inside the input buffer,
                    # so redirect the read index for this iteration.
                    c = correct_idx
                chunk = output[c]
                if chunk is None or chunk != expected_chunks[correct_idx]:
                    print(f"Rank {r} chunk {c} is incorrect should be index {correct_idx} from all ranks given {chunk}")
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # For inplace ReduceScatter the output buffer is a pointer into the input buffer
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index + rank * self.chunk_factor
        else:
            return buffer, index
# SendRecv is a collective that sends a chunk from one rank to another
# It is used to test the correctness of the send and receive instructions
class SendRecv(Collective):
    """Two-rank exchange: each rank ends up with the other rank's chunks."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        assert num_ranks == 2, "SendRecv only supports 2 ranks"
        super().__init__(num_ranks, chunk_factor, inplace)
        self.name = "sendrecv"

    def init_buffers(self):
        """Each rank starts with its own chunks in the input buffer."""
        rank_buffers = []
        for r in range(self.num_ranks):
            input_buffer = [Chunk(r, c, -1, c) for c in range(self.chunk_factor)]
            output_buffer = [None] * self.chunk_factor
            rank_buffers.append({Buffer.input: input_buffer, Buffer.output: output_buffer})
        return rank_buffers

    def check(self, prog):
        """Verify each rank holds the peer rank's chunks, in order."""
        correct = True
        buff_type = Buffer.input if self.inplace else Buffer.output
        for r in range(self.num_ranks):
            output = prog.buffers[r][buff_type]
            for c in range(self.chunk_factor):
                chunk = output[c]
                if chunk is None or chunk.origin_rank != 1 - r or chunk.origin_index != c:
                    print(f"Rank {r} chunk {c} is incorrect should be ({1 - r}, {c}) given {chunk}")
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # Inplace: the output buffer aliases the input buffer.
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index
        return buffer, index
class Broadcast(Collective):
    """Broadcast: the root rank's chunks are replicated to every rank."""

    def __init__(self, num_ranks, chunk_factor, inplace, root):
        # Bug fix: `root` was previously passed positionally into Collective's
        # num_ranks_per_node parameter, corrupting the node topology.
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "broadcast"
        self.root = root

    # Initializes input buffer for a broadcast
    def init_buffers(self):
        """Initialize buffers: only the root rank starts with real data."""
        rank_buffers = []
        if self.inplace:
            # Inplace broadcast only uses the input buffer.
            for r in range(self.num_ranks):
                input_buffer = [None] * (self.chunk_factor)
                for ch in range(self.chunk_factor):
                    input_buffer[ch] = Chunk(self.root, ch, -1, ch)
                buffers = {
                    Buffer.input: input_buffer,
                    Buffer.output: input_buffer,
                }
                rank_buffers.append(buffers)
        else:
            for r in range(self.num_ranks):
                input_buffer = [None] * self.chunk_factor
                output_buffer = [None] * self.chunk_factor
                if r == self.root:
                    for ch in range(self.chunk_factor):
                        input_buffer[ch] = Chunk(self.root, ch, -1, ch)
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
                rank_buffers.append(buffers)
        return rank_buffers

    # Expected output buffer for broadcast
    def check(self, prog):
        """Verify every rank's output buffer holds the root's chunks in order."""
        correct = True
        buf = Buffer.output
        for r in range(self.num_ranks):
            # Bug fix: previously inspected prog.buffers[0] for every rank.
            output = prog.buffers[r][buf]
            for ch in range(self.chunk_factor):
                index = ch
                chunk = output[index]
                if chunk is None:
                    # Bug fix: this message referenced an undefined name `i`.
                    print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given None")
                    correct = False
                elif chunk.origin_rank != self.root or chunk.origin_index != ch:
                    print(
                        f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
                    )
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        return buffer, index

View File

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.dag.lower import DagLower
from mscclpp.language.dag.optimizer import DagOptimizer

View File

@@ -0,0 +1,373 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import defaultdict, deque

from mscclpp.language.buffer import Buffer
from mscclpp.language.types import (
    Channel,
    ChannelType,
    ChunkRef,
    Instruction,
    Op,
)
class InstructionDAG:
def __init__(self, num_ranks: int, buffers: list):
self.num_ranks = num_ranks
self.buffers = buffers
# State for the actual instruction DAG
self.operations = {} # slot -> operations
self.last_writer = {} # slot -> last writing op
self.last_readers = defaultdict(list) # slot -> list of last reading ops
# State for the MSCCLPP-IR
self.tbs = []
for _ in range(num_ranks):
self.tbs.append({})
self.tb_mapping = {}
self.num_channels = [1] * num_ranks
self.tb_steps = [{} for _ in range(num_ranks)]
def convert_set_list(self):
ops = []
visited = set()
for slot, op in self.operations.items():
if op.inst == Instruction.start:
op.next = list(op.next)
for o in op.next:
ops.append(o)
elif op.inst != Instruction.copy:
ops.append(op)
while len(ops) > 0:
op = ops[0]
if op not in visited:
visited.add(op)
op.next = list(op.next)
ops = ops[1:] + op.next
else:
ops = ops[1:]
return visited
    def complete_channels(self):
        """Compute the channel set each threadblock needs and attach it as tb.channels.

        Scans every op in every threadblock and derives a Channel from the op's
        source/destination buffers, channel type, and peer rank(s). NVLS group
        ops collect a list of peer ranks; ordinary send/recv ops use a single
        peer rank. Barriers contribute no channels.
        """
        send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
        recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
        group_send_op = [Instruction.group_store]
        group_recv_op = [Instruction.group_load_reduce]
        for rank, rank_tbs in enumerate(self.tbs):
            for tbid, tb in rank_tbs.items():
                chans = set()
                for op in tb.ops:
                    if op.inst == Instruction.barrier:
                        continue
                    # Any buffer other than input/output is treated as scratch.
                    if op.src != None:
                        src_buffer = (
                            Buffer.scratch
                            if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
                            else op.src.buffer
                        )
                    # NOTE(review): if op.src or op.dst is None, src_buffer/dst_buffer may be
                    # unbound when used below — assumes channel-bearing ops always set both.
                    if op.dst != None:
                        dst_buffer = (
                            Buffer.scratch
                            if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
                            else op.dst.buffer
                        )
                    if op.channel_type == ChannelType.nvls:
                        if op.inst in group_send_op:
                            # Group store: one channel covering all destination ranks.
                            ranks = [dst[0].rank for dst in op.dsts]
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
                            chans.add(chan)
                        elif op.inst in group_recv_op:
                            # Group load-reduce: one channel covering all source ranks.
                            ranks = [src[0].rank for src in op.srcs]
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
                            chans.add(chan)
                    else:
                        if op.inst in send_op:
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
                            chans.add(chan)
                        elif op.inst in recv_op:
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
                            chans.add(chan)
                tb.channels = list(chans)
# InstructionDAG - builds the roots of the DAG
def add_start(self, rank, buffer, index, ref):
slot = (rank, buffer, index)
op = Op(Instruction.start, rank, ref, ref, next=set(), prev=set())
self.operations[slot] = op
self.last_writer[slot] = op
# InstructionDAG - adds a copy node
def add_copy(self, rank, send_ref, recv_ref, tb, trans_from_packet=False, trans_to_packet=False):
tb_step = self._get_tb_step(rank, tb)
if trans_from_packet:
op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
elif trans_to_packet:
op = Op(
Instruction.transform_to_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step
)
else:
op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
dstbuffer = recv_ref.buffer
dstindex = recv_ref.index
srcbuffer = send_ref.buffer
srcindex = send_ref.index
size = recv_ref.size
# Sending part of copy [Read]
self._read(rank, srcbuffer, srcindex, size, op)
# Receiving part of copy [Write]
self._write(rank, dstbuffer, dstindex, size, op)
return op
# InstructionDAG - adds a reduce node
def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False):
tb_step = self._get_tb_step(rank, tb)
if use_packet:
op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
else:
op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
dstbuffer = recv_ref.buffer
dstindex = recv_ref.index
srcbuffer = send_ref.buffer
srcindex = send_ref.index
size = recv_ref.size
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
# Sending part of reduce
self._read(rank, srcbuffer, srcindex, size, op)
# Reduce part of copy
self._write(rank, dstbuffer, dstindex, size, op, read=True)
return op
# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False):
    """Add a put (or put_packet) op pushing send_ref to the remote recv_ref
    over a channel of ch_type. Returns the created op."""
    step = self._get_tb_step(rank, tb)
    inst = Instruction.put_packet if use_packet else Instruction.put
    op = Op(
        inst,
        rank,
        send_ref,
        recv_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=step,
    )
    # Only the local (sending) side participates in dependency tracking.
    self._read(rank, send_ref.buffer, send_ref.index, send_ref.size, op)
    op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), step))
    op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    return op
def add_get(self, rank, send_ref, recv_ref, tb, ch_type):
    """Add a get op pulling the remote send_ref into the local recv_ref.
    Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=step
    )
    # Only the local (receiving) side participates in dependency tracking.
    self._write(rank, recv_ref.buffer, recv_ref.index, recv_ref.size, op)
    op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), step))
    return op
# InstructionDAG - adds a signal node.
def add_signal(self, rank, send_ref, recv_ref, tb, ch_type):
    """Add a signal op notifying the peer that data for recv_ref is ready.
    Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.signal,
        rank,
        send_ref,
        recv_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=step,
    )
    # treat signal as a write. signal acts as a barrier for the next instruction which prevents the
    # below instructions to be scheduled above the signal instruction.
    self._write(rank, send_ref.buffer, send_ref.index, send_ref.size, op)
    op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), step))
    op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    return op
def add_flush(self, rank, send_ref, recv_ref, tb):
    """Add a flush op for outstanding proxy-channel puts on send_ref.
    Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.flush,
        rank,
        send_ref,
        recv_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ChannelType.proxy,  # flush is only defined for proxy channels
        step=step,
    )
    self._read(rank, send_ref.buffer, send_ref.index, send_ref.size, op)
    op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), step))
    op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    return op
def add_wait(self, rank, dst_ref, src_ref, tb, ch_type):
    """Add a wait op blocking until the peer's signal for dst_ref arrives.
    Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=step
    )
    # Waiting is modeled as a write to the destination chunk so later ops
    # on that chunk depend on the wait.
    self._write(rank, dst_ref.buffer, dst_ref.index, dst_ref.size, op)
    op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), step))
    op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), step))
    return op
def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
    """Add a read_reduce_copy op reducing the remote send_ref into the local
    recv_ref. Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.read_reduce_copy,
        rank,
        send_ref,
        recv_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=step,
    )
    op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    # Accumulation into recv_ref is a read-modify-write.
    self._write(rank, recv_ref.buffer, recv_ref.index, recv_ref.size, op, read=True)
    return op
def add_barrier(self, rank, tb_list, barrier_id):
    """Insert one barrier op into every threadblock in tb_list on `rank`.

    Each barrier is modeled as a write over every buffer of the rank, so any
    later op touching those buffers depends on it.
    """
    rank_buffers = self.buffers[rank]
    for tb in tb_list:
        step = self._get_tb_step(rank, tb)
        # Each op gets its own metadata dict (instancing mutates copies later).
        barrier_info = {"tb_list": tb_list, "barrier_id": barrier_id}
        op = Op(Instruction.barrier, rank, None, None, next=set(), prev=set(), tb=tb, step=step, extra=barrier_info)
        for buffer_type, buffer in rank_buffers.items():
            self._write(rank, buffer_type, 0, len(buffer), op)
def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type):
    """Add a group_load_reduce op that reduces every ref in send_refs into
    the local recv_ref over an nvls-style grouped channel.

    Returns the created op (fix: this helper previously returned None,
    unlike every other single-op add_* helper in this class).
    """
    tb_step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.group_load_reduce,
        rank,
        recv_ref,
        recv_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=tb_step,
    )
    # treat recv_ref as src for group_load_reduce
    op.srcs.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
    for send_ref in send_refs:
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
    # Accumulation into recv_ref is a read-modify-write.
    self._write(rank, recv_ref.buffer, recv_ref.index, recv_ref.size, op, read=True)
    return op
def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type):
    """Add a group_store op scattering the local send_ref to every ref in
    recv_refs. Returns the created op."""
    step = self._get_tb_step(rank, tb)
    op = Op(
        Instruction.group_store,
        rank,
        send_ref,
        send_ref,
        next=set(),
        prev=set(),
        tb=tb,
        channel_type=ch_type,
        step=step,
    )
    # treat send_ref as dst for group_store
    op.dsts.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), step))
    for recv_ref in recv_refs:
        op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), step))
    self._read(rank, send_ref.buffer, send_ref.index, send_ref.size, op)
    return op
def _get_tb_step(self, rank: int, tb: int):
    """Return the next sequential step number for threadblock `tb` on `rank`.

    The first call for a (rank, tb) pair yields 0; each later call
    increments the stored counter by one.
    """
    steps = self.tb_steps[rank]
    if tb not in steps:
        steps[tb] = 0
        return 0
    steps[tb] += 1
    return steps[tb]
# InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce)
def _write(self, rank, buffer, index, size, op, read=False):
    """Record `op` as the writer of `size` slots starting at (rank, buffer,
    index) and link prev/next dependency edges.

    When `read` is True the op is a read-modify-write (e.g. a reduce), so
    every destination slot must already have a writer.
    """
    prev_ops = set()
    for i in range(index, index + size):
        slot = (rank, buffer, i)
        if read:
            assert slot in self.last_writer, f"Destination slot has never been written before a reduce {op}"
        # First write to this slot
        if slot not in self.operations:
            self.operations[slot] = op
        # If there are active readers - these are the previous operations
        # Else the previous operation is the last write (if there is one)
        readers = self.last_readers[slot]
        if len(readers) > 0:
            prev_ops.update(readers)
        elif slot in self.last_writer:
            prev_ops.add(self.last_writer[slot])
        # Set the last_writer to this op, and clear all readers
        self.last_writer[slot] = op
        self.last_readers[slot] = []
    # Update the next pointer of the previous ops
    for prev_op in prev_ops:
        prev_op.next.add(op)
        op.prev.add(prev_op)
# InstructionDAG helper - identifies the dependencies for read-type operations (send, copy, reduce)
def _read(self, rank, buffer, index, size, op):
    """Record `op` as a reader of `size` slots starting at (rank, buffer,
    index); the op becomes dependent on each slot's last writer.
    """
    prev_ops = set()
    for i in range(index, index + size):
        slot = (rank, buffer, i)
        # A slot must be produced (written) before anything can read it.
        assert slot in self.last_writer, f"Slot has never been written before a read-type {op}"
        # The previous operation for a reader is the last write to the slot
        writer = self.last_writer[slot]
        prev_ops.add(writer)
        self.last_readers[slot].append(op)
    # Update the next pointer of the previous ops
    for prev_op in prev_ops:
        prev_op.next.add(op)
        op.prev.add(prev_op)

View File

@@ -0,0 +1,162 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from typing import List
from mscclpp.language.buffer import Buffer
from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock
class DagLower:
    """Lowers an InstructionDAG into the per-GPU intermediate representation.

    The pipeline (see ``lower``) is: infer cross-threadblock dependencies,
    flatten scratch buffers into one global scratch buffer, replicate
    threadblocks per instance, then emit ``Gpu`` objects.
    """

    def __init__(self, dag: InstructionDAG):
        self.dag = dag
        # Per-rank {tb_id: Threadblock} mapping after replication.
        self.instanced_tbs = []

    def lower(self, instances: int, replication_policy: ReplicationPolicy):
        """Run the full lowering pipeline and return the lowered Gpu list."""
        self._infer_dependencies()
        self._lower_buffers(instances)
        self._replicate(instances, replication_policy)
        return self._lower_tbs()

    def _replicate(self, instances: int, replication_policy: ReplicationPolicy):
        """Clone every threadblock `instances` times.

        Chunk indices are remapped per `replication_policy`, barrier
        metadata is renumbered, and dependency edges are re-pointed at the
        cloned ops afterwards.
        """
        # update op step
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                for id, op in enumerate(tb.ops):
                    op.step = id

        if instances == 1:
            # Nothing to replicate; reuse the original threadblocks directly.
            self.instanced_tbs = self.dag.tbs
            return

        self.instanced_tbs = []
        for _ in range(self.dag.num_ranks):
            self.instanced_tbs.append({})

        def get_new_index(rank, buffer, index, size, i):
            # interleaved: instance i's chunks are interleaved element-wise;
            # otherwise each instance occupies a contiguous buffer-sized region.
            if replication_policy == ReplicationPolicy.interleaved:
                return index * instances + i * size
            return len(self.dag.buffers[rank][buffer]) * i + index

        def get_instance_ref(ref):
            # Remap a ChunkRef into instance i's index space (closure over i).
            if ref is None:
                return None
            iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i)
            iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size)
            return iref

        def update_extra(op, ori_op):
            # Barrier metadata (tb ids / barrier id) must follow the
            # instanced threadblock numbering (tb * instances + i).
            if op.inst == Instruction.barrier:
                tb_list = ori_op.extra["tb_list"]
                new_tb_list = [tb * instances + i for tb in tb_list]
                op.extra["tb_list"] = new_tb_list
                op.extra["barrier_id"] = ori_op.extra["barrier_id"] * instances + i

        for i in range(instances):
            # Generate all the threadblocks and ops
            for rank, rank_tbs in enumerate(self.dag.tbs):
                # rank_channels = self.num_channels[rank]
                for tbid, tb in rank_tbs.items():
                    itbid = tbid * instances + i
                    itb = Threadblock(id=itbid)
                    itb.ops = [None] * len(tb.ops)
                    for s, op in enumerate(tb.ops):
                        isrc = get_instance_ref(op.src)
                        idst = get_instance_ref(op.dst)
                        idepends = []
                        # Note: We don't need the fill out the rest of the metadata since replication is the last optimization
                        iop = Op(
                            op.inst,
                            op.rank,
                            isrc,
                            idst,
                            idepends,
                            op.step,
                            itbid,
                            channel_type=op.channel_type,
                            extra=copy.deepcopy(op.extra),
                        )
                        update_extra(iop, op)
                        itb.ops[s] = iop
                        for src, step in op.srcs:
                            isrc = get_instance_ref(src)
                            iop.srcs.append((isrc, step))
                        for dst, step in op.dsts:
                            idst = get_instance_ref(dst)
                            iop.dsts.append((idst, step))
                    for chan in tb.channels:
                        itb.channels.append(chan)
                    # NOTE(review): `op` here is the last op of the inner loop;
                    # an empty tb.ops would raise NameError — confirm tbs are
                    # always non-empty at this point.
                    self.instanced_tbs[op.rank][itbid] = itb

        # Redo dependency analysis
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                for i in range(instances):
                    itbid = tbid * instances + i
                    itb = self.instanced_tbs[rank][itbid]
                    for op, iop in zip(tb.ops, itb.ops):
                        iop.depends = [None] * len(op.depends)
                        for s, dep in enumerate(op.depends):
                            dep_tbid = dep.tb
                            dep_itbid = dep_tbid * instances + i
                            dep_step = dep.step
                            iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]

    # Convert local scratch buffers to index into one global scratch buffer
    def _lower_chunk(self, chunk):
        """Rewrite a scratch ChunkRef into the global scratch buffer; input
        and output refs pass through unchanged."""
        if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output:
            buffer = self.dag.buffers[chunk.rank][chunk.buffer].get_buffer()
            index = self.dag.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index)
            return ChunkRef(chunk.rank, buffer, index, chunk.size)
        return chunk

    # Assigns each scratch buffer an offset into the global scratch buffer
    def _lower_buffers(self, instances):
        for rank_buffers in self.dag.buffers:
            offset = 0
            for key, buf in rank_buffers.items():
                if key is not Buffer.input and key is not Buffer.output:
                    buf.set_offset(offset)
                    offset += buf.instance_size() * instances

    def _lower_tbs(self) -> List[Gpu]:
        """Lower every chunk reference and wrap the threadblocks per rank
        into Gpu objects."""
        gpus = []
        for rank, rank_tbs in enumerate(self.instanced_tbs):
            lowered_tbs = {}
            for tbid, tb in rank_tbs.items():
                for op in tb.ops:
                    op.src = self._lower_chunk(op.src)
                    op.dst = self._lower_chunk(op.dst)
                    # Order srcs/dsts by their tb step before dropping the step.
                    srcs = sorted(op.srcs, key=lambda x: x[1])
                    dsts = sorted(op.dsts, key=lambda x: x[1])
                    op.srcs = [self._lower_chunk(src[0]) for src in srcs]
                    op.dsts = [self._lower_chunk(dst[0]) for dst in dsts]
                lowered_tbs[tbid] = tb
            gpus.append(Gpu(rank, list(lowered_tbs.values())))
        return gpus

    def _infer_dependencies(self):
        """Derive op.depends from prev edges via BFS over the DAG, keeping at
        most one dependency per source threadblock (the latest one).
        """
        visited = set()
        for _, op in self.dag.operations.items():
            if op in visited:
                continue
            frontier = [op]
            while len(frontier) > 0:
                op = frontier[0]
                if op in visited:
                    frontier = frontier[1:]
                    continue
                # Dependencies for every op is the same as the ops that are stored in prev
                # Filter out dependencies that are satisified by tbs executing ops sequentially
                # If multiple dependent ops from the same tb keep the one that happens last
                depends = {}
                for dep_op in list(op.prev):
                    if dep_op.inst != Instruction.start:
                        tb = dep_op.tb
                        if tb not in depends or dep_op.step > depends[tb].step:
                            depends[tb] = dep_op
                op.depends = list(depends.values())
                visited.add(op)
                # NOTE(review): list + op.next assumes op.next is a list here,
                # though ops are created with next=set() — confirm a set-to-list
                # conversion pass runs before lowering.
                frontier = frontier[1:] + op.next

View File

@@ -0,0 +1,405 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.utils import (
buf_dst_src_match,
circular_dep_after_merge,
merge_op,
remove_op,
same_chan_type,
same_count,
same_buf_dst,
same_buf_src,
same_src_dst_buffer_type,
same_tb,
all_prevs_visited_after_merge,
)
from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.types import ChunkRef, ChannelType, Instruction, Op, Threadblock
class _InstructionOptimizer:
    """Peephole rewrites applied to ops within a threadblock.

    Every try_* method checks a fusion/merge precondition; on success it
    mutates the first op in place, unlinks the second from the DAG via
    merge_op, removes it from the threadblock and the work queue, and
    returns True.
    """

    def try_merge_same_instructions(
        self,
        op: Op,
        next_op: Op,
        tb: Threadblock,
        queue: list,
        expected_next_inst: Instruction,
        same_buf_func: callable,
    ) -> bool:
        """
        Attempts to merge two instructions of the same type if conditions are met.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :param expected_next_inst: The instruction type expected for the next operation.
        :param same_buf_func: The function to check if the buffer is the same (same_buf_dst or same_buf_src).
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == expected_next_inst
            and same_tb(op, next_op)
            and same_buf_func(op, next_op)
            and same_count(op, next_op)
            and same_chan_type(op, next_op)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            # Append the source chunks from next_op
            op.srcs.append(
                (
                    ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size),
                    next_op.step,
                )
            )
            # For 'signal' and 'wait' instructions, append destination chunks too
            if expected_next_inst in [Instruction.signal, Instruction.wait, Instruction.flush]:
                op.dsts.append(
                    (
                        ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),
                        next_op.step,
                    )
                )
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_compact_instructions(
        self, op: Op, tb: Threadblock, queue: list, inst_type: Instruction, same_src_dst_func: callable
    ) -> bool:
        """
        Try to compact instructions with the same instruction type. This optimization will
        compact multiple sequential instructions of the same type into a single instruction.
        :param op: The current operation.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations (the candidate is queue[1], the op
            immediately following `op` in program order).
        :param inst_type: The type of the instruction being processed (get, put, put_packet).
        :param same_src_dst_func: Predicate checking src/dst buffer-type compatibility.
        :return: True if operations are merged, False otherwise.
        """
        if len(queue) > 1:
            seq_op = queue[1]
            if (
                seq_op.inst == inst_type
                and same_src_dst_func(op, seq_op)
                and same_chan_type(op, seq_op)
                and same_count(op, seq_op)
                and not circular_dep_after_merge(op, seq_op)
                and all_prevs_visited_after_merge(op, seq_op)
            ):
                # Append the source and destination chunks from seq_op
                op.dsts.append(
                    (
                        ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size),
                        seq_op.step,
                    )
                )
                op.srcs.append(
                    (
                        ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size),
                        seq_op.step,
                    )
                )
                merge_op(op, seq_op)
                tb.ops.remove(seq_op)
                queue.remove(seq_op)
                return True
        return False

    def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool:
        """
        Attempts to fuse 'put' operations with other operations like read_reduce_copy, reduce, etc.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :return: True if operations are merged, False otherwise.
        """
        if (
            (next_op.inst == Instruction.put or next_op.inst == Instruction.put_packet)
            and same_tb(op, next_op)
            and same_count(op, next_op)
            and buf_dst_src_match(op, next_op)
            and next_op.channel_type == ChannelType.sm
            and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            # All fused destinations must live in the same buffer type.
            if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
                return False
            # Adjust instruction type and channel if needed
            if op.inst == Instruction.read_reduce_copy:
                op.inst = Instruction.read_reduce_copy_send
            elif op.inst == Instruction.reduce:
                op.inst = Instruction.reduce_send
                op.channel_type = ChannelType.sm
            elif op.inst == Instruction.reduce_packet:
                op.inst = Instruction.reduce_send_packet
                op.channel_type = ChannelType.sm
            # Append the destination chunk from next_op
            op.dsts.append(
                (
                    ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),
                    next_op.step,
                )
            )
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_fuse_instructions_using_proxy_channel(
        self, op: Op, next_op: Op, tb: Threadblock, queue: list, expected_next_inst: Instruction
    ) -> bool:
        """
        Attempts to fuse operations which use a proxy channel
        (put+signal -> put_with_signal; put_with_signal+flush -> put_with_signal_and_flush).
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :param expected_next_inst: The instruction type expected for the next operation.
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == expected_next_inst
            and same_tb(op, next_op)
            and same_count(op, next_op)
            and same_buf_dst(op, next_op)
            and same_buf_src(op, next_op)
            and same_chan_type(op, next_op)
            and op.channel_type == ChannelType.proxy
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            if op.inst == Instruction.put and next_op.inst == Instruction.signal:
                op.inst = Instruction.put_with_signal
            elif op.inst == Instruction.put_with_signal and next_op.inst == Instruction.flush:
                op.inst = Instruction.put_with_signal_and_flush
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_fuse_with_group_store(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool:
        """
        Attempts to fuse 'group_load_reduce' operations with 'group_store' operations.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == Instruction.group_store
            and same_count(op, next_op)
            and buf_dst_src_match(op, next_op)
            and same_chan_type(op, next_op)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            # Append the destination chunk from next_op
            op.inst = Instruction.group_load_reduce_store
            op.src = next_op.src
            for dst in next_op.dsts:
                op.dsts.append(dst)
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool:
        """Unlink `pending_remove_op` from the DAG when `condition` holds."""
        if condition:
            remove_op(pending_remove_op)
            return True
        return False
class DagOptimizer:
    """Drives the peephole optimization passes over every threadblock.

    Each pass walks a work queue of a threadblock's ops; when a fusion
    succeeds the same op is re-examined (continue) because more neighbors
    may now qualify, otherwise the queue advances by one.
    """

    def __init__(self, instruction_dag: InstructionDAG):
        self.optimizer = _InstructionOptimizer()
        self.dag = instruction_dag

    def remove_redundant_signal_wait(self):
        """Drop signal/wait ops made unnecessary by packet-based transfers."""
        # For packet ops, we can remove signal/wait
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst == Instruction.put_packet:
                        # A signal after put_packet is redundant: the packet
                        # itself carries the readiness flag.
                        for next_op in op.next:
                            fused = self.optimizer.try_remove_op(next_op, next_op.inst == Instruction.signal)
                            if fused:
                                break
                    elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet:
                        # Likewise, waiting before consuming a packet is redundant.
                        for prev_op in op.prev:
                            fused = self.optimizer.try_remove_op(prev_op, prev_op.inst == Instruction.wait)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    def fuse_instructions(self):
        """Run all fusion passes in their required order."""
        self._fuse_instructions_using_proxy_channel()
        self._fuse_same_instructions()
        self._optimize_rrcs_rs()
        self._optimize_group_ops()
        self._compact_instructions()

    # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di)
    # -> putWithSignal(src, sbuf, si, dst, dbuf, di)
    # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) flush(src, sbuf, si, dst, dbuf, di)
    # -> putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    def _fuse_instructions_using_proxy_channel(self):
        # Maps a fusable instruction to the follow-up instruction it absorbs.
        inst_followup_map = {
            Instruction.put: Instruction.signal,
            Instruction.put_with_signal: Instruction.flush,
        }
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_followup_map:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_instructions_using_proxy_channel(
                                op, next_op, tb, queue, inst_followup_map[op.inst]
                            )
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di)
    # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di])
    # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
    # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di)
    # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di)
    def _fuse_same_instructions(self):
        # Mapping instruction to their respective condition checks and same buffer function
        instruction_handlers = {
            Instruction.read_reduce_copy: same_buf_dst,
            Instruction.reduce: same_buf_dst,
            Instruction.reduce_packet: same_buf_dst,
            Instruction.signal: same_buf_src,
            Instruction.wait: same_buf_dst,
        }
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    inst_type = op.inst
                    if inst_type in instruction_handlers:
                        for next_op in op.next:
                            same_buf_func = instruction_handlers[inst_type]
                            if self.optimizer.try_merge_same_instructions(
                                op, next_op, tb, queue, inst_type, same_buf_func
                            ):
                                fused = True
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_)
    # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_)
    def _optimize_rrcs_rs(self):
        inst_types = [
            Instruction.read_reduce_copy,
            Instruction.reduce,
            Instruction.reduce_packet,
            Instruction.read_reduce_copy_send,
            Instruction.reduce_send,
            Instruction.reduce_send_packet,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_types:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_with_put(op, next_op, tb, queue)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # glre(srcs, sbuf, si, _, _, _), gstore (_, _, _, dsts, dbuf, di) -> glres(srcs, sbuf, si, dsts, dbuf, di)
    def _optimize_group_ops(self):
        inst_types = [
            Instruction.group_load_reduce,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_types:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_with_group_store(op, next_op, tb, queue)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # merge ops which are independent of other operations and no other operations in between
    # get(src, sbuf. si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di])
    # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di])
    # putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    # putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    # -> putWithSignal/putWithSignalAndFlush(list[src,sbuf,si], list[dst,dbuf,di])
    # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_,_])
    def _compact_instructions(self):
        campactable_inst = [
            Instruction.get,
            Instruction.put,
            Instruction.put_packet,
            Instruction.put_with_signal,
            Instruction.put_with_signal_and_flush,
            Instruction.signal,
            Instruction.flush,
            Instruction.wait,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                # tb.id == -1 marks an unassigned threadblock; skip it.
                if tb.id == -1:
                    continue
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in campactable_inst:
                        fused = self.optimizer.try_compact_instructions(
                            op, tb, queue, op.inst, same_src_dst_buffer_type
                        )
                    if fused:
                        continue
                    queue = queue[1:]

View File

@@ -0,0 +1,534 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass
import json
from typing import Dict, List, Optional, Union
from mscclpp.language.types import Buffer, ChannelType, Op, Program, Instruction
# Instructions whose src operand addresses a buffer local to the executing
# GPU; used below to size input/output/scratch buffers from actual usage.
_local_src_insts_mscclpp: set = {
    Instruction.put,
    Instruction.put_packet,
    Instruction.signal,
    Instruction.flush,
    Instruction.put_with_signal,
    Instruction.put_with_signal_and_flush,
    Instruction.copy,
    Instruction.copy_packet,
    Instruction.transform_to_packet,
    Instruction.reduce,
    Instruction.reduce_packet,
    Instruction.reduce_send,
    Instruction.reduce_send_packet,
    Instruction.group_load_reduce_store,
    Instruction.group_store,
}
# Instructions whose dst operand addresses a buffer local to the executing GPU.
_local_dst_insts_mscclpp: set = {
    Instruction.get,
    Instruction.wait,
    Instruction.read_reduce_copy,
    Instruction.copy,
    Instruction.copy_packet,
    Instruction.transform_to_packet,
    Instruction.reduce,
    Instruction.read_reduce_copy_send,
    Instruction.reduce_send,
    Instruction.reduce_packet,
    Instruction.reduce_send_packet,
    Instruction.group_load_reduce_store,
    Instruction.group_load_reduce,
}
# Instructions that never require an extra nop/sync barrier inserted before
# them during post-processing (packet ops synchronize via the packet flag).
_insts_no_need_sync_barrier: set = {
    Instruction.copy_packet,
    Instruction.reduce_packet,
    Instruction.reduce_send_packet,
    Instruction.barrier,
}
def ir_to_json(program: Program):
    """Post-process a lowered Program in place and serialize it to JSON.

    Performs, in order: buffer sizing from usage, per-GPU channel
    collection, dependency pruning, nop expansion, and step/tb renumbering,
    then delegates to _dump_to_json.
    """
    # Figure out sizes of buffers based on usage
    buffer_sizes = defaultdict(lambda: 0)
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for op in tb.ops:
                if op.inst in _local_src_insts_mscclpp:
                    key = (gpu.rank, op.src.buffer)
                    buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size)
                    for src in op.srcs:
                        key = (gpu.rank, src.buffer)
                        buffer_sizes[key] = max(buffer_sizes[key], src.index + src.size)
                if op.inst in _local_dst_insts_mscclpp:
                    key = (gpu.rank, op.dst.buffer)
                    buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size)
                    # ignore remote buffers
                    if (
                        op.inst != Instruction.read_reduce_copy_send
                        and op.inst != Instruction.reduce_send
                        and op.inst != Instruction.reduce_send_packet
                    ):
                        for dst in op.dsts:
                            key = (gpu.rank, dst.buffer)
                            buffer_sizes[key] = max(buffer_sizes[key], dst.index + dst.size)
    for gpu in program.gpus:
        gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks)
        gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks)
        gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks)

    # Since LL protocol will double the scratch size. We need to make sure all GPUs have the same scratch size.
    # Otherwise the offset calculation will be wrong.
    if program.protocol == "LL":
        max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
        for gpu in program.gpus:
            gpu.scratch_chunks = max_scratch

    # get channel info for each GPU and threadblock
    for gpu in program.gpus:
        gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id)
        chan_dict = {}
        # the channel key is the tuple (srcBuffer, dstBuffer, type)
        for tb in gpu.threadblocks:
            for ch in tb.channels:
                key = (ch.srcBuffer, ch.dstBuffer, ch.type)
                if key not in chan_dict:
                    chan_dict[key] = [(tb.id, ch.connected_to)]
                else:
                    chan_dict[key].append((tb.id, ch.connected_to))
        for key, value in chan_dict.items():
            chan_dict[key] = sorted(value)
        gpu.channels = chan_dict

    # Remove the dependencies of wait after signal. They are actually depends on remote chunk
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for op in tb.ops:
                if op.inst == Instruction.wait:
                    op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends))

    # Filter out redundant dependencies
    # e.g. if op1 and op2 depend on op, and op1 happens before op2
    # then op2 does not need to explicitly depend on op
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            running_depends = []
            for op in tb.ops:
                op.depends = list(filter(lambda dep: dep not in running_depends, op.depends))
                running_depends = running_depends + op.depends

    # Do some additional postprocessing of operations:
    # - Expand operations with dependencies with no-ops
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            new_ops = []
            for op in tb.ops:
                if op.inst in _insts_no_need_sync_barrier:
                    new_ops.append(op)
                    continue
                # Expand extra dependencies into nop operations
                nop = Op(Instruction.nop, -1, None, None, [])
                for i, dep in enumerate(op.depends):
                    # barrier already syncs all threads
                    if dep.inst != Instruction.barrier:
                        nop.depends.append(dep)
                # Fold consecutive sync points: attach deps to a preceding
                # barrier/nop instead of inserting a new nop.
                if len(new_ops) > 0 and (
                    new_ops[-1].inst == Instruction.barrier or new_ops[-1].inst == Instruction.nop
                ):
                    new_ops[-1].depends.extend(nop.depends)
                elif len(nop.depends) > 0:
                    new_ops.append(nop)
                new_ops.append(op)
            tb.ops = new_ops

    # update step and tid for ops
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for i, op in enumerate(tb.ops):
                op.step = i
                op.tb = tb.id

    # Need to calculate channel info for each GPU
    # NOTE(review): nchannels is computed but never used or returned —
    # confirm whether this bookkeeping is still needed.
    nchannels = 0
    for gpu in program.gpus:
        max_tb_channels = 0
        if len(gpu.threadblocks) > 0:
            max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks)
        nchannels = max(nchannels, max_tb_channels)
    return _dump_to_json(program)
@dataclass
class _JsonInstruction:
    """Flat, JSON-serializable view of one executed instruction.

    Fields left as None are simply absent for the given instruction kind;
    i_* fields describe the inbound (receiving) side and o_* the outbound
    (sending) side.
    """

    name: str
    # Inbound buffer names and channel ids.
    i_buff: Optional[Dict[str, str]] = None
    i_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None
    # Outbound buffer names and channel ids.
    o_buff: Optional[Dict[str, str]] = None
    o_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None
    # Primary source: rank, chunk list, buffer name, offset.
    src: Optional[int] = None
    srcs: Optional[List[Dict[str, Union[int, str]]]] = None
    srcbuff: Optional[str] = None
    srcoff: Optional[int] = None
    # Primary destination: rank, chunk list, buffer name, offset.
    dst: Optional[int] = None
    dsts: Optional[List[Dict[str, Union[int, str]]]] = None
    dstbuff: Optional[str] = None
    dstoff: Optional[int] = None
    # Channel type, chunk count, dependency list.
    ctype: Optional[str] = None
    cnt: Optional[int] = None
    deps: Optional[List[Dict[str, int]]] = None
    # Barrier-specific metadata.
    nthread_blocks: Optional[int] = None
    barrier_id: Optional[int] = None
class _OpConverter(ABC):
    """Base class for converting an Op into its _JsonInstruction form."""

    def get_channel_ids(self, chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type):
        """Look up channel ids in the threadblock channel table.

        For nvls channels a channel matches when its connected rank set
        equals the set of chunk ranks; for other channel types one id (with
        the chunk's offset) is emitted per chunk whose rank equals the
        channel's peer.
        """
        channel_ids = []
        key = (src_buffer, dst_buffer, chan_type)
        if chan_type == ChannelType.nvls:
            ranks = [c.rank for c in chunk_list]
            channel_ids.extend(
                [{"id": id} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if set(ele) == set(ranks)]
            )
        else:
            for c in chunk_list:
                channel_ids.extend(
                    [
                        {"id": id, "off": c.index}
                        for id, ele in enumerate(tb_channel_dict[key]["connectedTo"])
                        if ele == c.rank
                    ]
                )
        return channel_ids

    @abstractmethod
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        # Fix: the abstract signature now matches every concrete converter,
        # all of which take (op, tb_channel_dict).
        pass
class _SignalFlushConverter(_OpConverter):
    """Serialize signal/flush ops: only the outbound side is emitted."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        out_cids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        return _JsonInstruction(
            name=op.inst.value,
            o_buff={"src": op.src.buffer.value, "dst": op.dst.buffer.value},
            o_cids=out_cids,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _WaitConverter(_OpConverter):
    """Serializes wait ops: input channels only, no data payload."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        in_buffers = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        in_channels = self.get_channel_ids(
            op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=in_buffers,
            i_cids=in_channels,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _ReadReduceCopyConverter(_OpConverter):
    """Serializes read-reduce-copy ops: input channels plus local src/dst chunks."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        in_buffers = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        in_channels = self.get_channel_ids(
            op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        local_dst = op.dst
        local_src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=in_buffers,
            dst=local_dst.rank,
            dstbuff=local_dst.buffer.value,
            dstoff=local_dst.index,
            src=local_src.rank,
            srcbuff=local_src.buffer.value,
            srcoff=local_src.index,
            i_cids=in_channels,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _ReadReduceCopySendConverter(_OpConverter):
    """Serializes read-reduce-copy-send ops: input and output channels plus chunks."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        in_channels = self.get_channel_ids(
            op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        out_channels = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type
        )
        in_buffers = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        out_buffers = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
        local_dst = op.dst
        local_src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=in_buffers,
            i_cids=in_channels,
            o_buff=out_buffers,
            o_cids=out_channels,
            src=local_src.rank,
            srcbuff=local_src.buffer.value,
            srcoff=local_src.index,
            dst=local_dst.rank,
            dstbuff=local_dst.buffer.value,
            dstoff=local_dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _ReduceSendConverter(_OpConverter):
    """Serializes reduce-send ops; the send side always uses an sm channel."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        out_channels = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm
        )
        out_buffers = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
        reduce_inputs = [{"buff": s.buffer.value, "off": s.index} for s in op.srcs]
        local_dst = op.dst
        local_src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            o_buff=out_buffers,
            o_cids=out_channels,
            src=local_src.rank,
            srcbuff=local_src.buffer.value,
            srcoff=local_src.index,
            srcs=reduce_inputs,
            dst=local_dst.rank,
            dstbuff=local_dst.buffer.value,
            dstoff=local_dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _ReduceConverters(_OpConverter):
    """Serializes local reduce ops (no channels involved)."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        reduce_inputs = [{"buff": s.buffer.value, "off": s.index} for s in op.srcs]
        local_dst = op.dst
        local_src = op.dst
        return _JsonInstruction(
            name=op.inst.value,
            srcs=reduce_inputs,
            dst=local_dst.rank,
            dstbuff=local_dst.buffer.value,
            dstoff=local_dst.index,
            src=local_src.rank,
            srcbuff=local_src.buffer.value,
            srcoff=local_src.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _NopConverter(_OpConverter):
    """Serializes nop ops, which only carry cross-threadblock dependencies."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        dep_list = [{"tb": d.tb, "step": d.step} for d in op.depends]
        return _JsonInstruction(
            name=op.inst.value,
            deps=dep_list,
        )
class _BarrierConverter(_OpConverter):
    """Serializes barrier ops from the metadata stashed in ``op.extra``."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        participating_tbs = op.extra["tb_list"]
        return _JsonInstruction(
            name=op.inst.value,
            nthread_blocks=len(participating_tbs),
            barrier_id=op.extra["barrier_id"],
        )
class _PutConverter(_OpConverter):
    """Serializes the put family: output channels plus the source chunk list."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        out_buffers = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        out_channels = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        put_sources = [{"buff": s.buffer.value, "off": s.index} for s in op.srcs]
        return _JsonInstruction(
            name=op.inst.value,
            o_buff=out_buffers,
            o_cids=out_channels,
            srcs=put_sources,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _GetConverter(_OpConverter):
    """Serializes get ops: input channels plus the destination chunk list."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        in_buffers = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        in_channels = self.get_channel_ids(
            op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        get_targets = [{"buff": d.buffer.value, "off": d.index} for d in op.dsts]
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=in_buffers,
            i_cids=in_channels,
            dsts=get_targets,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _CopyConverter(_OpConverter):
    """Serializes local copy ops (no channels involved)."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        return _JsonInstruction(
            name=op.inst.value,
            src=op.src.rank,
            srcbuff=op.src.buffer.value,
            srcoff=op.src.index,
            dst=op.dst.rank,
            dstbuff=op.dst.buffer.value,
            dstoff=op.dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
class _GroupLoadReduceStoreConverter(_OpConverter):
    """Serializes NVLS group-load-reduce-store ops: chunks plus both channel sides."""

    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        in_channels = self.get_channel_ids(
            op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        out_channels = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type
        )
        return _JsonInstruction(
            name=op.inst.value,
            src=op.src.rank,
            srcbuff=op.src.buffer.value,
            srcoff=op.src.index,
            dst=op.dst.rank,
            dstbuff=op.dst.buffer.value,
            dstoff=op.dst.index,
            i_cids=in_channels,
            o_cids=out_channels,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )
# Dispatch table mapping each DAG instruction to the converter that serializes
# it; converters are stateless, so one shared instance per class is enough.
_json_converter_map: Dict[Instruction, _OpConverter] = {
    Instruction.signal: _SignalFlushConverter(),
    Instruction.flush: _SignalFlushConverter(),
    Instruction.wait: _WaitConverter(),
    Instruction.read_reduce_copy: _ReadReduceCopyConverter(),
    Instruction.read_reduce_copy_send: _ReadReduceCopySendConverter(),
    Instruction.reduce_send: _ReduceSendConverter(),
    Instruction.reduce_send_packet: _ReduceSendConverter(),
    Instruction.reduce: _ReduceConverters(),
    Instruction.reduce_packet: _ReduceConverters(),
    Instruction.nop: _NopConverter(),
    Instruction.barrier: _BarrierConverter(),
    Instruction.put: _PutConverter(),
    Instruction.put_packet: _PutConverter(),
    Instruction.put_with_signal: _PutConverter(),
    Instruction.put_with_signal_and_flush: _PutConverter(),
    Instruction.get: _GetConverter(),
    Instruction.copy: _CopyConverter(),
    Instruction.copy_packet: _CopyConverter(),
    Instruction.transform_to_packet: _CopyConverter(),
    Instruction.group_load_reduce_store: _GroupLoadReduceStoreConverter(),
}
def _dump_to_json(program: Program):
    """Serialize a lowered Program into the JSON execution-plan string."""
    gpus = []
    def remove_empty_fields(d):
        # Drop unset instruction fields so the emitted JSON stays compact.
        return {k: v for k, v in d.items() if v not in [None, "", [], {}]}
    # NVLS rank groups are sized by the largest buffer of each kind across GPUs.
    max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
    max_input = max(gpu.input_chunks for gpu in program.gpus)
    max_output = max(gpu.output_chunks for gpu in program.gpus)
    # NOTE(review): the loop variables `id` and `type` below shadow builtins.
    for id, gpu in enumerate(program.gpus):
        gpu_instance = {
            "id": id,
            "inputChunks": gpu.input_chunks,
            "outputChunks": gpu.output_chunks,
            "scratchChunks": gpu.scratch_chunks,
            "chunkGroups": program.num_chunk_groups,
            "threadblocks": [],
            "channels": [],
        }
        # GPU-level channel list: one entry per (srcBuffer, dstBuffer, type).
        for (srcBuffer, dstBuffer, type), channels in gpu.channels.items():
            obj = {
                "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer,
                "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer,
                "type": type.value,
                "connectedTo": [ch[1] for ch in channels],
            }
            if type == ChannelType.nvls:
                # NVLS channels connect rank groups; sort for a stable output.
                obj["connectedTo"] = [sorted(list(peers)) for peers in obj["connectedTo"]]
            gpu_instance["channels"].append(obj)
        gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"]))
        gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"]))
        # render for GPU NVLS channels
        for i, chan in enumerate(gpu_instance["channels"]):
            if chan["type"] == "nvls":
                buff = chan["srcbuff"]
                buffer_size = (
                    max_input
                    if buff == Buffer.input.value
                    else max_output if buff == Buffer.output.value else max_scratch
                )
                gpu_instance["channels"][i] = {
                    "buff": chan["srcbuff"],
                    "type": chan["type"],
                    "rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]],
                }
        # Per-threadblock channel tables and serialized op lists.
        for tb in gpu.threadblocks:
            if tb.id < 0:
                # Negative ids mark threadblocks that were never scheduled.
                continue
            ops = []
            tb_channels = []
            tb_channel_dict = {}
            for (srcBuffer, dstBuffer, type), channels in gpu.channels.items():
                # Keep only the channel slots owned by this threadblock (ele[0] == tb.id).
                obj = {
                    "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer,
                    "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer,
                    "type": type.value,
                    "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id],
                    "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id],
                }
                if len(obj["chanIds"]) > 0:
                    tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj
                    tb_channels.append(obj)
            tb_channels = filter(lambda x: x["type"] != "none", tb_channels)
            tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"]))
            for op in tb.ops:
                if op.tb == -1:
                    # Unassigned ops are skipped rather than serialized.
                    continue
                instr = _json_converter_map[op.inst].to_json(op, tb_channel_dict)
                ops.append(remove_empty_fields(asdict(instr)))
            threadblock = {
                "id": tb.id,
                "ops": ops,
                "channels": list(
                    map(
                        lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]},
                        tb_channels,
                    )
                ),
            }
            gpu_instance["threadblocks"].append(threadblock)
        gpus.append(gpu_instance)
    obj = {
        "name": program.name,
        "collective": program.collective,
        "protocol": program.protocol,
        "inplace": program.inplace,
        "gpus": gpus,
        "num_threads_per_block": program.num_threads_per_block,
        "use_double_scratch_buffer": program.use_double_scratch_buffer,
        "min_message_size": program.min_message_size,
        "max_message_size": program.max_message_size,
    }
    return json.dumps(obj, indent=2)

View File

@@ -0,0 +1,433 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from mscclpp.language.collectives import Collective
from mscclpp.language.buffer import *
from mscclpp.language.types import ChannelType, ChunkRef, ReplicationPolicy, Threadblock
from mscclpp.language.ir import *
from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG
from mscclpp.language.rank import Rank
_current_program = None
def _curr():
global _current_program
if _current_program == None:
raise RuntimeError("No Program in context")
return _current_program
# For an msccl++ program, we assume a channel can be identified by (send_buffer, recv_buffer, type, send_tb/recv_tb),
# which means send_tb and recv_tb must match for a signal/wait pair, and likewise for a put/get operation.
# If a sender wants to deliver data to a different threadblock on the receiver side, it must first send to the
# matching threadblock there and then perform a cross-threadblock sync. This is a limitation of the current implementation.
class MSCCLPPProgram:
    """Builder/context for a communication program written in the MSCCL++ DSL.

    While the program is active (inside a ``with`` block) chunk operations are
    recorded into an InstructionDAG; ``lower()`` optimizes the DAG and produces
    an executable ``Program``, and ``generate_json()`` serializes that plan.
    """

    def __init__(
        self,
        name: str,
        collective: Collective,
        num_ranks: int,
        instances: int,
        protocol: str = "Simple",
        instr_fusion: bool = True,
        replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
        num_threads_per_block: int = 1024,
        use_double_scratch_buffer: bool = False,
        min_message_size: int = 0,
        max_message_size: int = 2**64 - 1,
    ):
        self.name = name
        self.collective = collective
        self.num_ranks = num_ranks
        self.instances = instances
        self.protocol = protocol
        self.instr_fusion = instr_fusion
        self.replication_policy = replication_policy
        self.num_threads_per_block = num_threads_per_block
        self.use_double_scratch_buffer = use_double_scratch_buffer
        self.min_message_size = min_message_size
        self.max_message_size = max_message_size
        assert protocol in ("Simple", "LL"), f"Given protocol: {protocol}. Must be either Simple, LL"
        self.run_opt = True  # Runs optimization passes
        # Initialize the input buffers
        self.buffers = collective.init_buffers()
        self.instr_dag = InstructionDAG(self.num_ranks, self.buffers)
        self.ranks = []
        for r in range(self.num_ranks):
            self.ranks.append(Rank(r))
            # Seed the DAG with a start node for every input chunk on this rank.
            for index, _ in enumerate(self.buffers[r][Buffer.input]):
                buffer, index = self.collective.get_buffer_index(r, Buffer.input, index)
                ref = self.get_ref(r, buffer, index, 1)
                self.instr_dag.add_start(r, buffer, index, ref)

    def __enter__(self):
        """Make this program the active context; only one may be active at a time."""
        global _current_program
        if _current_program is not None:
            raise RuntimeError("There is already a MSCCLPP Program in context")
        _current_program = self
        # Return self so `with MSCCLPPProgram(...) as prog:` binds the program.
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global _current_program
        if _current_program is not self:
            raise RuntimeError("This program is not currently in context")
        _current_program = None

    def _convert_to_execution_plan(self):
        """Distribute the DAG's ops into per-rank threadblocks, ordered by step."""
        ops = self.instr_dag.convert_set_list()
        ops = sorted(ops, key=lambda x: x.step)
        for op in ops:
            rank = op.rank
            tbid = op.tb
            if tbid not in self.instr_dag.tbs[rank]:
                self.instr_dag.tbs[rank][tbid] = Threadblock(id=tbid)
            tb = self.instr_dag.tbs[rank][tbid]
            tb.ops.append(op)

    def get_rank_ref(self, rank):
        """Return a RankRef handle for rank-level operations (e.g. barrier)."""
        return RankRef(rank, self)

    # Tracks a send operation on the buffers
    def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            db[dst_index + i] = sb[src_index + i]

    # Tracks a reduce operation on the buffers
    def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            reduce_chunk = db[dst_index + i]
            sent_chunk = sb[src_index + i]
            db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)

    def get_ref(self, rank, buffer, index, size):
        """Return a Ref to `size` chunks at (buffer, index) on `rank`."""
        buffer, index = self.collective.get_buffer_index(rank, buffer, index)
        return Ref(rank, buffer, index, size, self)

    def get_chunks(self, rank, buffer, index, size=1):
        """Return the chunk objects in the given range; None for unset slots."""
        chunks = [None] * size
        for i in range(0, size):
            if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]):
                chunks[i] = self.buffers[rank][buffer][index + i]
            else:
                chunks[i] = None
        return chunks

    def check_buffer_exists(self, rank, name):
        # Lazily create named scratch buffers on first use.
        if name not in self.buffers[rank]:
            self.buffers[rank][name] = BufferSlice(Buffer.scratch, name)

    # Checks that all chunks that should be on each rank
    # are present in the output buffer.
    def check(self):
        return self.collective.check(self)

    # Lower program to MSCCLPP
    def lower(self):
        """Optimize the instruction DAG and lower it to an executable Program."""
        self._convert_to_execution_plan()
        self.instr_dag.complete_channels()
        dag_optimizer = DagOptimizer(self.instr_dag)
        dag_optimizer.remove_redundant_signal_wait()
        if self.instr_fusion:
            dag_optimizer.fuse_instructions()
        dag_lower = DagLower(self.instr_dag)
        gpu_prgms = dag_lower.lower(self.instances, self.replication_policy)
        program = Program(
            self.name,
            self.collective.name,
            self.collective.inplace,
            self.protocol,
            gpu_prgms,
            self.collective.num_chunk_groups * self.instances,
            self.num_threads_per_block,
            self.use_double_scratch_buffer,
            self.min_message_size,
            self.max_message_size,
        )
        # Chunk counts scale with the number of replicated instances.
        for gpu in program.gpus:
            gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances
            if not self.collective.inplace:
                gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances
        return program

    def generate_json(self):
        """Serialize the lowered program to its JSON execution-plan string."""
        return ir_to_json(self.lower())
def Json():
    """Print the current in-context program lowered to its JSON execution plan."""
    print(_curr().generate_json())
@dataclass
class RankRef:
    """Handle for rank-scoped DSL operations on the owning program."""
    rank: int  # rank id this handle refers to
    prog: MSCCLPPProgram  # program that created this handle
    def _get_barrier_id(self, tb_list) -> int:
        # Reuse an existing barrier id for an identical tb_list, else allocate one.
        return self.prog.ranks[self.rank].get_barrier_id(tb_list)
    def barrier(self, tb_list):
        """Insert a barrier synchronizing the listed threadblocks on this rank."""
        barrier_id = self._get_barrier_id(tb_list)
        return self.prog.instr_dag.add_barrier(self.rank, tb_list, barrier_id)
@dataclass
class Ref(ChunkRef):
    """Chunk reference bound to a program; exposes the DSL chunk operations.

    Each operation records nodes into the program's InstructionDAG and, for
    data-moving ops, simulates the transfer on the program's buffers.
    """

    prog: MSCCLPPProgram

    def __repr__(self):
        return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})"

    def _end(self):
        # Exclusive end index of this chunk range.
        return self.index + self.size

    def _get_chunk(self, index):
        return self.prog.buffers[self.rank][self.buffer][index]

    def split(self, num):
        """Split this reference into `num` equally sized sub-references."""
        assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts"
        chunks = [None] * num
        size = self.size // num
        for i in range(num):
            index = self.index + i * size
            chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size)
        return chunks

    def group(self, other):
        """Return a reference spanning this chunk range and `other` (same rank/buffer)."""
        assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}"
        assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}"
        if self.index < other.index:
            first = self
            second = other
        else:
            first = other
            second = self
        end = max(first._end(), second._end())
        return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog)

    def _get_buffer_index(self, remote_rank, buffer, index):
        # Default to this chunk's own buffer/index; for named scratch buffers
        # with no explicit index, append at the end of the remote buffer slice.
        if index == -1 and buffer is None:
            return self.buffer, self.index
        elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
            return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
        return buffer, index

    def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
        self.prog.check_buffer_exists(dst, buffer)
        assert self.rank != dst, "Cannot put to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)
        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
        if use_packet:
            # Packet puts are self-synchronizing, so record no-op signal/wait
            # markers instead of real channel synchronization.
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
            self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
            self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
        else:
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
        return dst_chunkref

    def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
        """Send this chunk to rank `dst` at (buffer, index); returns the dst ref."""
        return self._put(dst, buffer, index, sendtb, chan_type)

    def put_packet(
        self,
        dst,
        buffer=None,
        index=-1,
        sendtb=-1,
        chan_type=ChannelType.sm,
        temp_buffer=None,
        temp_buffer_index=-1,
    ):
        """Send this chunk in packet format; proxy channels need a staging buffer."""
        chunk_ref = self
        if chan_type == ChannelType.proxy:
            assert temp_buffer is not None, "Need to specify a temporary buffer for proxy channels"
            # Stage the data in packet format locally before the proxy put.
            chunk_ref = self._copy(
                self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
            )
        return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)

    def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
        """Fetch the chunk at (buffer, index) on rank `src` into this chunk."""
        self.prog.check_buffer_exists(src, buffer)
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot get from the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)
        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
        self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size)
        self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type)

    # For signal and wait we currently assume the pair uses the same tb index.
    # In the future the tb index should be inferred from the instruction DAG:
    # with a channel defined as (send_tb, src_buffer, recv_tb, dst_buffer, type),
    # the DAG info could be used to reduce the number of channels.
    def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
        """Signal rank `dst` that this chunk's data is ready."""
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot signal to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)
        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)

    # only proxy channel need to use this function
    def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy):
        """Flush outstanding proxy-channel sends toward rank `dst`."""
        assert chan_type == ChannelType.proxy, "Only proxy channel can use flush"
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot flush to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)
        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)

    def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
        """Wait for the signal from rank `src` for this chunk."""
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot wait on the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)
        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
        self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type)

    def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False):
        self.prog.check_buffer_exists(dst, buffer)
        buffer, index = self._get_buffer_index(dst, buffer, index)
        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        # Check if we are copying the chunk to the same index (easy mistake when we are using inplace)
        if dst_chunkref == self:
            return
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
        assert self.rank == dst, "Chunk copy only supports intra-rank communication"
        self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet)
        return dst_chunkref

    # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
    def copy(self, dst, buffer=None, index=-1, sendtb=-1):
        return self._copy(dst, buffer, index, sendtb)

    def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
        """Copy from packet format back to plain data on the same rank."""
        return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

    def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
        dst = self.rank
        src = other_chunkref.rank
        self.prog.apply_reduce(
            src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
        )
        if use_packet:
            assert src == dst, "Packet reduce only supports intra-rank communication"
        if src != dst:
            self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type)
        else:
            self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet)
        return self

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm):
        return self._reduce(other_chunkref, recvtb, channel_type)

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce_packet(self, other_chunkref, recvtb=-1):
        return self._reduce(other_chunkref, recvtb, use_packet=True)

    # Group operations: collective operations performed across multiple chunks.
    # For now, all chunks must have the same buffer type and offset.

    # Reads the chunk(s) referenced by other_chunkrefs and reduces into the chunk referenced by this chunkref
    def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
        assert (
            len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
        ), "Group load reduce only supports nvls channel"
        nranks_per_node = self.prog.collective.num_ranks_per_node
        for other_chunkref in other_chunkrefs:
            assert (
                self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node
            ), "Group load reduce only supports chunks on the same node"
            assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
            assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"
            src_chunkref = other_chunkref
            self.prog.apply_reduce(
                src_chunkref.rank,
                src_chunkref.buffer,
                src_chunkref.index,
                self.rank,
                self.buffer,
                self.index,
                self.size,
            )
        self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
        return self

    # Copies the chunk(s) referenced by this chunkref onto the destination ranks
    def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
        for dst in dsts:
            self.prog.check_buffer_exists(dst, buffer)
        assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
        assert chan_type == ChannelType.nvls, "Group store only supports nvls channel"
        other_chunkrefs = []
        nrank_per_node = self.prog.collective.num_ranks_per_node
        for dst in dsts:
            # All destinations must be directly linked on the same node.
            buffer, index = self._get_buffer_index(dst, buffer, index)
            assert self.buffer == buffer, "Group store only supports chunks with the same buffer"
            assert (
                self.rank // nrank_per_node == dst // nrank_per_node
            ), "Group store only supports chunks on the same node"
            dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
            self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
            other_chunkrefs.append(dst_chunkref)
        # Record a single group-store op covering all destinations.
        self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type)

    def get_origin_index(self, index=0):
        """Return the original buffer index of the chunk at offset `index`."""
        return self._get_chunk(index + self.index).origin_index

    def get_origin_rank(self, index=0):
        """Return the original rank of the chunk at offset `index`."""
        return self._get_chunk(index + self.index).origin_rank

    def get_dst_index(self, index=0):
        """Return the destination buffer index of the chunk at offset `index`."""
        return self._get_chunk(index + self.index).dst_index

    def get_dst_rank(self, index=0):
        """Return the destination rank of the chunk at offset `index`."""
        return self._get_chunk(index + self.index).dst_rank

    def print_chunk_info(self, index=0):
        """Debug helper: print the chunk object at offset `index`."""
        print(self._get_chunk(index + self.index))
def chunk(rank, buffer, index, size=1) -> Ref:
    """Return a Ref to `size` chunks at (buffer, index) on `rank`.

    Returns None when the chunk at the given index has not been written yet.
    """
    # Look up the current program once instead of calling _curr() twice.
    prog = _curr()
    if prog.buffers[rank][buffer][index] is None:
        return None
    return prog.get_ref(rank, buffer, index, size)
def rank(rank) -> RankRef:
    """Return a RankRef for rank-level operations (e.g. barrier) on `rank`."""
    return _curr().get_rank_ref(rank)
def Check():
    """Verify the current program's buffers match the collective's expected output."""
    return _curr().check()

View File

@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass, field
from typing import Dict
class BarrierInfo:
    """Value object identifying a barrier by the threadblocks it spans.

    Two BarrierInfo objects compare equal iff they cover the same tb_list,
    so a Rank can reuse one barrier id per distinct threadblock set.
    """

    def __init__(self, tb_list):
        self.tb_list = tb_list

    def __eq__(self, other):
        # Only compare against other BarrierInfo objects; previously a
        # comparison with any other type raised AttributeError on .tb_list.
        if not isinstance(other, BarrierInfo):
            return NotImplemented
        return self.tb_list == other.tb_list

    def __hash__(self):
        # tb_list is a list, so convert to a tuple to make it hashable.
        return hash(tuple(self.tb_list))
@dataclass
class Rank:
    """Per-rank bookkeeping; currently allocates stable barrier ids."""

    rank_id: int
    current_max_barrier_id: int = 0
    current_barriers: Dict[BarrierInfo, int] = field(default_factory=dict)

    def get_barrier_id(self, tb_list):
        """Return the barrier id for *tb_list*, allocating a fresh id on first use."""
        info = BarrierInfo(tb_list)
        if info not in self.current_barriers:
            self.current_barriers[info] = self.current_max_barrier_id
            self.current_max_barrier_id += 1
        return self.current_barriers[info]

View File

@@ -0,0 +1,173 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass, field
from enum import Enum
from typing import Union, List
from mscclpp.language.buffer import Buffer
@dataclass
class Gpu:
    """Per-GPU program state accumulated during lowering."""

    rank: int
    threadblocks: list = field(default_factory=list)
    # From ncclize
    precopies: list = field(default_factory=list)
    postcopies: list = field(default_factory=list)
    inputs: dict = field(default_factory=dict)
    outputs: dict = field(default_factory=dict)
    input_chunks: int = 0
    output_chunks: int = 0
    scratch_chunks: int = 0
    scratch: dict = field(default_factory=dict)
    channels: dict = field(default_factory=dict)

    def scratch_size(self):
        """Number of scratch slots: one past the highest used index (0 if empty)."""
        return max(self.scratch.values(), default=-1) + 1
@dataclass
class Program:
    """Fully lowered program: plan metadata plus the per-GPU instruction lists."""
    name: str  # program name used in the generated plan
    collective: str  # collective name (e.g. "allgather")
    inplace: bool  # whether the collective operates in place
    protocol: str  # "Simple" or "LL"
    gpus: List[Gpu] = field(default_factory=list)  # one entry per rank
    num_chunk_groups: int = 1
    num_threads_per_block: int = 1024
    use_double_scratch_buffer: bool = False
    min_message_size: int = 0
    max_message_size: int = 2**64 - 1  # default upper bound is effectively unlimited
@dataclass
class Threadblock:
    """One threadblock's ops and channel assignments within a GPU program."""
    channel: int = -1  # channel index used by this threadblock (-1 = unassigned)
    send: int = -1  # NOTE(review): presumably the send peer rank (-1 = none) — confirm
    recv: int = -1  # NOTE(review): presumably the recv peer rank (-1 = none) — confirm
    ops: list = field(default_factory=list)  # ordered Ops executed by this threadblock
    rbid: int = -1  # threadblock id of the receiver
    id: int = -1  # this threadblock's id (-1 until assigned)
    channels: list = field(default_factory=list)
    def __eq__(self, other):
        # Identity semantics: each Threadblock instance is a distinct node,
        # overriding the dataclass-generated field-wise equality.
        return self is other
    def __hash__(self):
        return id(self)
class ReplicationPolicy(Enum):
    """How chunk work is replicated across program instances."""
    # each instance deals with a different chunk
    # Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1
    duplicated = "duplicated"
    # each instance deals with a different chunk, in an interleaved way
    # Chunk A, Chunk B -> Chunk A0, Chunk A1, Chunk B0, Chunk B1
    interleaved = "interleaved"
    # pack multiple instances to deal with the same chunk and share the channels
    packed = "packed"
    def __str__(self):
        return self.value
class Instruction(Enum):
    """Opcodes emitted into the lowered execution plan (values are the JSON names)."""
    start = "start"
    nop = "nop"
    # reduce family (read-reduce-copy and fused variants)
    read_reduce_copy = "rrc"
    read_reduce_copy_send = "rrcs"
    reduce_send = "rs"
    copy = "copy"
    reduce = "reduce"
    # packet-format (LL protocol) variants
    copy_packet = "cpkt"
    transform_to_packet = "tpkt"
    reduce_send_packet = "rspkt"
    reduce_packet = "rpkt"
    # point-to-point data movement
    put = "put"
    put_packet = "ppkt"
    put_with_signal = "pws"
    put_with_signal_and_flush = "pwsf"
    get = "get"
    # synchronization
    wait = "wait"
    signal = "signal"
    flush = "flush"
    barrier = "barrier"
    # NVLS group operations
    group_store = "gstore"
    group_load_reduce = "glre"
    group_load_reduce_store = "glres"
    def __str__(self):
        return self.value
@dataclass
class ChunkRef:
    """Reference to `size` contiguous chunks starting at `index` of a rank's buffer."""
    rank: int  # rank owning the buffer
    buffer: Buffer  # which buffer (input/output/scratch) the chunks live in
    index: int  # starting chunk index within the buffer
    size: int  # number of contiguous chunks referenced
    def __hash__(self):
        # Explicit value-based hash so instances stay hashable alongside the
        # dataclass-generated field-wise __eq__.
        return hash((self.rank, self.buffer, self.index, self.size))
class ChannelType(Enum):
    """Transport used by a channel."""
    proxy = "proxy"  # proxy channel; the only type that supports flush
    sm = "sm"  # SM (direct peer-to-peer) channel
    none = "none"  # no channel / local-only operation
    nvls = "nvls"  # NVLS channel connecting a whole rank group
    def __str__(self):
        return self.value
@dataclass(frozen=True)
class Channel:
    """Immutable channel descriptor, usable as a dict key."""
    srcBuffer: Buffer  # buffer on the sending side
    dstBuffer: Buffer  # buffer on the receiving side
    type: ChannelType  # transport type of the channel
    connected_to: Union[int, List[int]]  # peer rank, or rank group for nvls
    def __hash__(self):
        # Ensure connected_to is converted to a tuple if it's a list
        connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to
        return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable))
@dataclass
class Op:
    """A single operation node in the instruction DAG."""
    inst: Instruction  # opcode
    rank: int  # rank executing this op
    src: ChunkRef  # primary source chunk reference
    dst: ChunkRef  # primary destination chunk reference
    depends: list = field(default_factory=list)  # cross-threadblock dependencies
    step: int = -1  # Step in the TB
    tb: int = -1  # TB this op is assigned to
    prev: list = field(default_factory=list)  # List of instructions that happen before
    next: list = field(default_factory=list)  # List of instructions that happen after
    channel: int = -1  # channel id within the threadblock (-1 = unassigned)
    channel_type: ChannelType = ChannelType.none  # transport used by this op
    srcs: list = field(default_factory=list)  # extra source chunks (fused ops)
    dsts: list = field(default_factory=list)  # extra destination chunks (fused ops)
    extra: dict = field(default_factory=dict)  # opcode-specific metadata (e.g. barrier info)
    def cnt(self):
        """Number of chunks this op covers; src and dst sizes must agree when both set."""
        if self.src:
            if self.dst:
                assert self.src.size == self.dst.size
            return self.src.size
        elif self.dst:
            return self.dst.size
        else:
            return 0
    def __eq__(self, other):
        # Identity semantics: each Op is a distinct DAG node, overriding the
        # dataclass-generated field-wise equality.
        return self is other
    def __hash__(self):
        return id(self)
    def __repr__(self):
        return f"Op({self.inst}, {self.rank}, {self.src}, {self.dst}, step:{self.step}, tb:{self.tb})"

View File

@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.types import Op
def remove_op(op: Op):
    """Splice *op* out of the DAG, connecting its predecessors to its successors."""
    for p in op.prev:
        p.next.remove(op)
        p.next += op.next
        p.next = list(set(p.next))  # dedupe after inheriting op's successors
    for n in op.next:
        n.prev.remove(op)
        # NOTE(review): op.prev is used as a set here (.union) but reset to a
        # list below — prev/next container types look inconsistent; confirm.
        n.prev = op.prev.union(n.prev)
    op.next = []
    op.prev = []
def merge_op(op: Op, other_op: Op):
    """Fuse *other_op* into *op*, rewiring all of its DAG edges onto *op*."""
    # Drop the direct edge between the two ops, if any, before rewiring.
    if other_op in op.next:
        op.next.remove(other_op)
        other_op.prev.remove(op)
    for p in other_op.prev:
        p.next.remove(other_op)
        p.next.append(op)
    for n in other_op.next:
        n.prev.remove(other_op)
        # NOTE(review): prev is treated as a set (.add/.union) while next is a
        # list — the container types look inconsistent; confirm intent.
        n.prev.add(op)
    op.prev = op.prev.union(other_op.prev)
    op.next = list(set(op.next + other_op.next))
def circular_dep_after_merge(op: Op, other_op: Op):
root = set([op, other_op])
frontier = set(op.next)
if other_op in frontier:
frontier.remove(other_op)
frontier = list(frontier.union(other_op.next))
while len(frontier) > 0:
current = frontier[0]
for n in current.next:
# The root node will be visited again if there is a circular dependency
if n in root:
return True
frontier.append(n)
frontier = frontier[1:]
def all_prevs_visited_after_merge(op: "Op", other_op: "Op"):
    """
    For case: op2.prev = [op1, op3]. op1.next = [op2]. op3.next = [op2]. And op1 and op2 are satisfied to merge.
    We only apply the merge if all previous ops of op2 are visited. (op1 is the last previous op of op2).
    """
    # A predecessor with a step greater than op's has not been visited yet.
    return all(prev.step <= op.step for prev in other_op.prev)
def same_tb(op1: Op, op2: Op):
    """True if both ops run in the same threadblock on the same channel.
    NOTE(review): shadowed at import time by a second ``same_tb`` later in
    this module that drops the channel check — confirm which is intended.
    """
    return op1.tb == op2.tb and op1.channel == op2.channel
def same_count(op1: Op, op2: Op):
    """True if both ops cover the same number of chunks."""
    return op1.cnt() == op2.cnt()
def same_buf_dst(op1: Op, op2: Op):
    """True if both ops target the same destination buffer and index."""
    return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index
def same_src_dst_buffer_type(op1: Op, op2: Op):
    """True if both ops use the same source and destination buffer types."""
    return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer
def buf_dst_src_match(op1: Op, op2: Op):
    """True if op1's destination is exactly op2's source (same buffer and index)."""
    return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index
def same_buf_src(op1: Op, op2: Op):
return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index
def same_chan_type(op1: Op, op2: Op):
return op1.channel_type == op2.channel_type
def same_tb(op1: Op, op2: Op):
    # True when both ops run on the same threadblock.
    # NOTE(review): this redefines `same_tb` from earlier in this module; this
    # later definition wins at import time and drops the earlier `channel`
    # comparison. Confirm which predicate is intended and remove the duplicate.
    return op1.tb == op2.tb

View File

@@ -0,0 +1,34 @@
[
{
"filename": "allgather_barrier.py",
"args": ["8", "8"]
},
{
"filename": "allreduce_allpairs_packet.py",
"args": ["8", "8"]
},
{
"filename": "allreduce_allpairs_get.py",
"args": ["8", "8"]
},
{
"filename": "allreduce_allpairs.py",
"args": ["8", "8"]
},
{
"filename": "allreduce_ring.py",
"args": ["8", "8"]
},
{
"filename": "send_recv_packet.py",
"args": ["2"]
},
{
"filename": "send_recv_proxy.py",
"args": ["2"]
},
{
"filename": "allreduce_nvls.py",
"args": ["8", "2"]
}
]

View File

@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import json
from pathlib import Path
import subprocess
def run_examples(input_folder, configs, output_folder):
    """Run each configured example script, capturing its output to a file.

    Each config entry supplies a "filename" (relative to *input_folder*) and
    an "args" list. Combined stdout+stderr is written to
    ``<output_folder>/<flattened name>.output``.
    """
    for entry in configs:
        script = entry["filename"]
        script_args = entry["args"]
        script_path = Path(input_folder) / script
        # Derive a flat output name: drop a trailing ".py", flatten subdirs.
        stem = script[: -len(".py")] if script.endswith(".py") else script
        stem = stem.replace("/", "_")
        out_path = Path(output_folder) / (stem + ".output")
        cmd = ["python3", str(script_path)] + script_args
        # Merge stderr into the captured output file.
        with open(out_path, "w") as sink:
            proc = subprocess.run(cmd, stdout=sink, stderr=subprocess.STDOUT, text=True)
        # Surface failures on the console but keep processing the rest.
        if proc.returncode != 0:
            print(f"Error running {script}. See {out_path} for details.")
def main(input_folder, config_path, output_folder):
    """Load the JSON run configuration and execute every example it lists."""
    with open(config_path, "r") as cfg_file:
        entries = json.load(cfg_file)
    # Make sure the destination exists before any output file is opened.
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    run_examples(input_folder, entries, output_folder)
if __name__ == "__main__":
    # CLI entry point: <input_folder> <config> <output_folder>
    cli = argparse.ArgumentParser(description="Process files according to a configuration and save the results.")
    cli.add_argument("input_folder", type=str, help="Path to the folder containing the input files.")
    cli.add_argument("config", type=str, help="Path to the configuration file.")
    cli.add_argument("output_folder", type=str, help="Path to the folder where the processed files will be saved.")
    parsed = cli.parse_args()
    main(parsed.input_folder, parsed.config, parsed.output_folder)