Mirror of https://github.com/microsoft/mscclpp.git
Merge mscclpp-lang to mscclpp project (#442)
First step to merge msccl-tools into the mscclpp repo. In this step we move all msccl-related code, pass the current tests, and do some necessary refactoring.

- Add the `mscclpp.language` module
- Add the `_InstructionOptimizer` and `DagOptimizer` classes to optimize the DAG
- Add `DagLower` to lower the DAG to an intermediate representation
- Add documentation for mscclpp.language
- Remove msccl-related code
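A rough sketch of how the new pieces fit together. The `DagOptimizer` entry point below is an assumption; only `DagLower.lower` appears with this signature in the diff:

# Hypothetical wiring of the merged components; DagOptimizer's method name is assumed.
from mscclpp.language.dag import InstructionDAG, DagOptimizer, DagLower
from mscclpp.language.types import ReplicationPolicy

dag = InstructionDAG(num_ranks=2, buffers=[{}, {}])
# ... a program records put/get/signal/wait operations into the DAG ...
DagOptimizer(dag).optimize()  # assumed optimization entry point
gpus = DagLower(dag).lower(1, ReplicationPolicy.interleaved)  # lower to MSCCLPP-IR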
python/examples/allgather_barrier.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import argparse
from mscclpp.language import *
from mscclpp.language.buffer import Buffer
from mscclpp.language.collectives import AllGather
from mscclpp.language.types import ChannelType, ReplicationPolicy


def allgather_test(gpus, instances):
    """
    Demonstrates how to use barrier in the MSCCL++ DSL with an allgather collective.
    This example uses an allpairs algorithm for the allgather operation.
    Steps:
    1. Each rank sends a chunk to all other ranks' output buffers and copies the chunk to its own output buffer.
    2. A barrier is called to synchronize the send and copy operations, and signal peers that the data has been sent.
    3. Wait for all the chunks from other ranks to be received.
    """
    size = gpus
    collective = AllGather(size, 1, False)
    with MSCCLPPProgram(
        "allgather_with_barrier",
        collective,
        size,
        instances,
        protocol="Simple",
        replication_policy=ReplicationPolicy.interleaved,
    ):
        for n in range(gpus):
            c = chunk(n, Buffer.input, 0, 1)
            for peer in range(gpus):
                if n != peer:
                    c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)
                else:
                    c.copy(n, Buffer.output, n, sendtb=peer)
            # explicit barrier
            r = rank(n)
            r.barrier(tb_list=list(range(gpus)))
            for peer in range(gpus):
                if n != peer:
                    c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm)

        for n in range(gpus):
            for peer in range(gpus):
                c = chunk(n, Buffer.output, peer, 1)
                if n != peer:
                    c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allgather_test(args.num_gpus, args.instances)
python/examples/allreduce_allpairs.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer


def allreduce_allpairs(gpus, instances, protocol):
    """
    Demonstrates allreduce with the all-pairs algorithm using put semantics.
    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Put the reduced data to all peers.
    4. Sync all ranks to ensure the data is received.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram("allreduce_pairs", collective, size, instances, protocol=protocol):
        for rank in range(size):
            for tb in range(size):
                index = rank * size
                c = chunk(rank, Buffer.input, index + tb)
                # step 1: make sure the data is ready
                for nghr in range(size):
                    peer_index = nghr * size
                    if rank != nghr:
                        # signal the peer that the buffer is ready
                        c_peer = chunk(rank, Buffer.input, peer_index + tb)
                        c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb)
                for nghr in range(size):
                    if rank != nghr:
                        c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
                # step 2: reduce the chunks and send to peers
                for nghr in range(size):
                    if rank != nghr:
                        c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb)
                for nghr in range(size):
                    if rank != nghr:
                        c.put(nghr, Buffer.input, index + tb, sendtb=tb)
                # step 3: signal that the peer's buffer is ready
                for nghr in range(size):
                    if rank != nghr:
                        c.signal(nghr, Buffer.input, index + tb, sendtb=tb)
                for nghr in range(size):
                    if rank != nghr:
                        peer_index = nghr * size
                        c_peer = chunk(rank, Buffer.input, peer_index + tb)
                        c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol")

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances, args.protocol)
python/examples/allreduce_allpairs_get.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer


def allreduce_allpairs(gpus, instances):
    """
    AllReduce with the all-pairs algorithm using get semantics.
    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Signal all ranks to notify that the data is ready.
    4. Wait for all chunks to be ready, then retrieve the chunks from all peers.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_pairs",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Each rank reduces its nth chunk with data read from all peers
        for rank in range(size):
            for tb in range(size):
                index = rank * size
                c = chunk(rank, Buffer.input, index + tb)
                # make sure the data is ready
                for nghr in range(size):
                    peer_index = nghr * size
                    if rank != nghr:
                        c_peer = chunk(rank, Buffer.input, peer_index + tb)
                        c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb)
                for nghr in range(size):
                    if rank != nghr:
                        c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
                # reduce the chunks
                for i in range(size):
                    nghr = (rank + i) % size
                    if rank != nghr:
                        c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb)
                for nghr in range(size):
                    if rank != nghr:
                        c.signal(nghr, Buffer.input, index + tb, sendtb=tb)

        # wait until all the chunks are ready, then get the chunks
        for rank in range(size):
            for tb in range(size):
                for nghr in range(size):
                    if rank != nghr:
                        index = nghr * size
                        c = chunk(rank, Buffer.input, index + tb)
                        c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
                for i in range(size):
                    nghr = (rank + i) % size
                    index = nghr * size
                    if rank != nghr:
                        c = chunk(rank, Buffer.input, index + tb)
                        c.get(nghr, Buffer.input, index + tb, recvtb=tb)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances)
python/examples/allreduce_allpairs_packet.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer


def allreduce_allpairs(gpus, instances):
    """
    AllReduce with the all-pairs algorithm using the packet format.
    Steps:
    1. Each rank sends its nth chunk to the nth rank's scratch space.
    2. Each rank performs a local reduction on its nth chunk using data from all other ranks' scratch spaces.
    3. Each rank sends the reduced data to all other ranks' scratch spaces.
    4. Each rank retrieves the final reduced result from the scratch space.
    """
    size = gpus
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_packets",
        collective,
        size,
        instances,
        protocol="LL",
        use_double_scratch_buffer=True,
    ):
        # Each rank sends the nth chunk to the nth rank into scratch space
        for r1 in range(size):
            for tb in range(size):
                if tb == r1:
                    continue
                remote_rank = tb
                index = remote_rank * size
                c = chunk(r1, Buffer.input, index, size)
                c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb)

        # Each rank performs a local reduction on the nth chunk,
        # utilizing `size` threadblocks for better parallelism
        for r in range(size):
            for index in range(size):
                c = chunk(r, Buffer.input, r * size + index)
                for peer in range(size):
                    if peer != r:
                        c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
                for peer in range(size):
                    if peer != r:
                        c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)

        # Each rank gets the final result from scratch space
        for r in range(size):
            for peer in range(size):
                if peer != r:
                    c = chunk(r, "scratch", size * size + peer * size, size)
                    c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

allreduce_allpairs(args.num_gpus, args.instances)
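To make the scratch-space arithmetic in the example above concrete, here is a small sketch of the index layout for size = 2, derived from the expressions in the code (raw chunks land at sender * size; reduced results land past the size * size boundary):

size = 2
raw_base = lambda sender: sender * size                   # put_packet destination for raw chunks
reduced_base = lambda owner: size * size + owner * size   # slot for owner's reduced results
assert [raw_base(s) for s in range(size)] == [0, 2]
assert [reduced_base(o) for o in range(size)] == [4, 6]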
python/examples/allreduce_nvls.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer


def allreduce_nvls(gpus, instances):
    """
    AllReduce via the NVLS channel.
    Steps:
    1. Sync all the ranks to make sure the data is ready.
    2. Call group_load_reduce to reduce the data.
    3. Call group_store to propagate the data to all the ranks.
    """
    size = gpus
    chunksperloop = gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_nvls",
        collective,
        size,
        instances,
    ):
        for rank in range(size):
            index = rank
            c = chunk(rank, Buffer.input, index)
            reduce_chunks = []
            # make sure the data is ready
            for nghr in range(size):
                if rank != nghr:
                    c_peer = chunk(nghr, Buffer.input, index)
                    reduce_chunks.append(c_peer)
                    c.signal(nghr, Buffer.input, index, sendtb=0)
            for nghr in range(size):
                if rank != nghr:
                    c.wait(nghr, Buffer.input, index, recvtb=0)
            c = c.group_load_reduce(reduce_chunks, recvtb=0)
            ngbrs = [nghr for nghr in range(size) if nghr != rank]
            c.group_store(ngbrs, sendtb=0)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

allreduce_nvls(args.num_gpus, args.instances)
python/examples/allreduce_ring.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer


def allreduce_ring(size, instances):
    """
    Implements a ring-based allreduce.
    Steps:
    1. Signal the next rank and wait for the signal from the previous rank, to make sure the data is ready on the previous rank.
    2. Reduce the data and send it to the next rank.
    3. After all the data is reduced, propagate it to all the ranks.
    """
    collective = AllReduce(size, size, True)
    with MSCCLPPProgram(
        "allreduce_ring",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Reduce ring
        for step in range(0, size - 1):
            for index in range(0, size):
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                c = chunk(rank, Buffer.input, index)
                c.signal(next_rank, Buffer.input, index, 0)
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, (index + size - 1) % size)
                c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0)
                c.reduce(chunk(prev_rank, Buffer.input, (index + size - 1) % size), recvtb=0)

        # Propagate ring
        for step in range(-1, size - 2):
            for index in range(0, size):
                rank = (index + step) % size
                c = chunk(rank, Buffer.input, index)
                next_rank = (index + step + 1) % size
                c.put(next_rank, Buffer.input, index, sendtb=0)
                c.signal(next_rank, Buffer.input, index, 0)
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, (index + size - 1) % size)
                c.wait(prev_rank, Buffer.input, (index + size - 1) % size, 0)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()

allreduce_ring(args.num_gpus, args.instances)
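An illustrative trace of the reduce phase above for size = 4, step = 0, showing which chunk each rank signals, waits on, and reduces (just enumerating the same index expressions as the code):

size, step = 4, 0
for index in range(size):
    rank = (index + step) % size
    next_rank = (index + step + 1) % size
    prev_rank = (index + step - 1) % size
    print(
        f"rank {rank}: signal chunk {index} -> rank {next_rank}; "
        f"wait on + reduce chunk {(index + size - 1) % size} from rank {prev_rank}"
    )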
python/examples/send_recv_packet.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.collectives import SendRecv
from mscclpp.language.buffer import Buffer
from mscclpp.language.types import ChannelType


def send_recv(instances):
    """
    Send and receive data between two ranks using proxy channels, with the LL protocol and a double scratch buffer.
    Steps:
    1. Each rank sends a chunk to the other rank's scratch buffer in packet format via a proxy channel.
    2. Wait for the data to be received, then copy it to the output buffer.
    """
    size = 2
    chunksperloop = 1
    collective = SendRecv(size, chunksperloop, False)
    with MSCCLPPProgram(
        "send_recv",
        collective,
        size,
        instances,
        protocol="LL",
        use_double_scratch_buffer=True,
    ):
        for r in range(size):
            for nghr in range(size):
                if nghr == r:
                    continue
                c = chunk(r, Buffer.input, 0)
                c.put_packet(
                    nghr,
                    "scratch",
                    1,
                    sendtb=0,
                    chan_type=ChannelType.proxy,
                    temp_buffer="scratch",
                    temp_buffer_index=0,
                )

        for r in range(size):
            c = chunk(r, "scratch", 1)
            c.copy_packet(r, Buffer.output, 0, sendtb=0)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

send_recv(args.instances)
python/examples/send_recv_proxy.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from mscclpp.language import *
from mscclpp.language.buffer import Buffer
from mscclpp.language.collectives import SendRecv
from mscclpp.language.types import ChannelType


def send_recv(instances):
    """
    Send and receive data between two ranks using proxy channels.
    Steps:
    1. Each rank sends a chunk to the other rank's scratch buffer and signals the other rank that the data has been sent.
    2. Wait for the data to be received, then copy it to the output buffer.
    """
    size = 2
    chunksperloop = 1
    collective = SendRecv(size, chunksperloop, False)
    with MSCCLPPProgram(
        "send_recv",
        collective,
        size,
        instances,
    ):
        for r in range(size):
            for nghr in range(size):
                if nghr == r:
                    continue
                c = chunk(r, Buffer.input, 0)
                c.put(
                    nghr,
                    "scratch",
                    1,
                    sendtb=0,
                    chan_type=ChannelType.proxy,
                )
                c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)
                c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.proxy)

        for r in range(size):
            c = chunk(r, "scratch", 1)
            c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy)
            c.copy(r, Buffer.output, 0, sendtb=0)

        Json()
        Check()


parser = argparse.ArgumentParser()
parser.add_argument("instances", type=int, help="number of instances")

args = parser.parse_args()

send_recv(args.instances)
python/mscclpp/language/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from mscclpp.language.program import MSCCLPPProgram, Json, Check, chunk, rank
python/mscclpp/language/buffer.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from enum import Enum


# Scratch buffer slice with manual indexing
class BufferSlice:
    def __init__(self, buf, name):
        self.name = name
        self.buf = buf
        self.offset = -1  # Offset into the global scratch buffer
        self.chunks = []

    # Returns the global index into the scratch buffer
    def get_global_index(self, index):
        assert self.offset > -1, "set_offset needs to be called first"
        return self.offset + index

    def get_buffer(self):
        return self.buf

    def instance_size(self):
        return len(self.chunks)

    def set_offset(self, offset):
        self.offset = offset

    def __getitem__(self, index):
        return self.chunks[index]

    def __setitem__(self, index, value):
        current_size = len(self.chunks)
        while index > current_size:
            self.chunks.append(None)
            current_size = len(self.chunks)
        if index == current_size:
            self.chunks.append(value)
        else:
            self.chunks[index] = value

    def __len__(self):
        return len(self.chunks)


class Buffer(Enum):
    input = "i"
    output = "o"
    scratch = "s"

    def __str__(self):
        return self.value

    def __lt__(self, other):
        return self.value < other.value

    def __gt__(self, other):
        return self.value > other.value
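A minimal sketch of how BufferSlice indexing behaves, with a hypothetical offset of 8 (in practice the offset is assigned later by DagLower._lower_buffers):

scratch = BufferSlice(Buffer.scratch, "scratch")
scratch[0] = None       # __setitem__ grows the slice to one chunk
scratch.set_offset(8)   # normally done during lowering
assert scratch.get_global_index(0) == 8
assert scratch.instance_size() == 1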
python/mscclpp/language/chunk.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


from dataclasses import dataclass


@dataclass
class Chunk:
    origin_rank: int  # Rank the chunk initially started at
    origin_index: int  # Index the chunk initially started at
    dst_rank: int = -1
    dst_index: int = -1

    def reduce(self, dst, chunk):
        if type(chunk) is ReduceChunk:
            return chunk.reduce(dst, self)
        elif type(chunk) is Chunk:
            chunks = [self, chunk]
            return ReduceChunk(dst, chunks)
        else:
            raise ValueError("Trying to reduce with chunk of None")

    def __hash__(self):
        return hash((self.origin_rank, self.origin_index))

    def __eq__(self, other):
        return (
            type(other) is Chunk and self.origin_rank == other.origin_rank and self.origin_index == other.origin_index
        )

    def __lt__(self, other):
        return self.origin_rank < other.origin_rank or (
            self.origin_rank == other.origin_rank and self.origin_index < other.origin_index
        )


@dataclass
class ReduceChunk:
    creation_rank: int  # Rank where the ReduceChunk is created. Necessary since the same ReduceChunk can be created on multiple ranks independently
    chunks: list  # List of chunks reduced

    def reduce(self, dst, chunk):
        if type(chunk) is ReduceChunk:
            chunks = self.chunks + chunk.chunks
        elif type(chunk) is Chunk:
            chunks = self.chunks + [chunk]
        else:
            raise ValueError("Trying to reduce with chunk of None")
        return ReduceChunk(self.creation_rank, chunks)

    def sort(self):
        self.chunks.sort()

    def __hash__(self):
        self.sort()
        return hash((self.creation_rank,) + tuple(self.chunks))

    # Two reduce chunks are equal if they contain the same list of
    # chunks being reduced
    def __eq__(self, other):
        self.sort()
        other.sort()
        return self.chunks == other.chunks
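A small sketch of the order-insensitive equality this gives: two ReduceChunks built by reducing the same chunks in different orders compare equal, because __eq__ sorts both chunk lists before comparing.

a = ReduceChunk(-1, []).reduce(-1, Chunk(0, 0)).reduce(-1, Chunk(1, 0))
b = ReduceChunk(-1, []).reduce(-1, Chunk(1, 0)).reduce(-1, Chunk(0, 0))
assert a == b  # both sort to [Chunk(0, 0), Chunk(1, 0)]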
python/mscclpp/language/collectives.py (new file, 339 lines)
@@ -0,0 +1,339 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from mscclpp.language.buffer import Buffer
from mscclpp.language.chunk import Chunk, ReduceChunk


class Collective:

    def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
        self.num_ranks = num_ranks
        self.chunk_factor = chunk_factor
        self.inplace = inplace
        self.name = "custom"
        # Divide the buffer into num_chunk_groups groups
        if num_ranks_per_node == -1:
            self.num_ranks_per_node = num_ranks
        else:
            self.num_ranks_per_node = num_ranks_per_node

        # kwargs
        # Number of chunk groups: we group the n chunks into m groups.
        # All groups have the same size, but within a group the chunk sizes
        # may differ when the number of chunks does not divide evenly into the groups.
        self.num_chunk_groups = kwargs.get("num_chunk_groups", 1)

    def init_buffers(self):
        pass

    def check(self, prog):
        pass

    def get_buffer_index(self, rank, buffer, index):
        return buffer, index


class AllToAll(Collective):

    def __init__(self, num_ranks, chunk_factor, inplace):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "alltoall"

    def init_buffers(self):
        chunks_per_node = self.num_ranks * self.chunk_factor
        rank_buffers = []
        for r in range(self.num_ranks):
            input_buffer = [None] * chunks_per_node
            output_buffer = [None] * chunks_per_node
            for index in range(chunks_per_node):
                chunk = Chunk(r, index, index // self.chunk_factor, index % self.chunk_factor + r * self.chunk_factor)
                input_buffer[index] = chunk
            if self.inplace:
                buffers = {Buffer.input: input_buffer, Buffer.output: input_buffer}
            else:
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
            rank_buffers.append(buffers)
        return rank_buffers

    # Expected output buffer for alltoall
    def check(self, prog):
        chunks_per_node = self.num_ranks * self.chunk_factor
        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][Buffer.output]
            for i in range(self.num_ranks):
                for ch in range(self.chunk_factor):
                    index = ch + i * self.chunk_factor
                    chunk = output[index]
                    expected_origin_index = ch + r * self.chunk_factor
                    if chunk is None or chunk.origin_rank != i or chunk.origin_index != expected_origin_index:
                        print(
                            f"Rank {r} chunk {index} is incorrect should be chunk({i},{expected_origin_index}) given {chunk}"
                        )
                        correct = False
        return correct


class AllGather(Collective):
    def __init__(self, num_ranks, chunk_factor, inplace):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "allgather"

    # Initializes input buffer for an allgather
    def init_buffers(self):
        rank_buffers = []
        if self.inplace:
            # Inplace AllGather only uses the output buffer
            for r in range(self.num_ranks):
                output_buffer = [None] * (self.num_ranks * self.chunk_factor)
                for rank in range(self.num_ranks):
                    for ch in range(self.chunk_factor):
                        output_buffer[rank * self.chunk_factor + ch] = Chunk(
                            rank, ch, -1, rank * self.chunk_factor + ch
                        )
                buffers = {
                    Buffer.input: output_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
                    Buffer.output: output_buffer,
                }
                rank_buffers.append(buffers)
        else:
            for r in range(self.num_ranks):
                input_buffer = [None] * self.chunk_factor
                output_buffer = [None] * (self.num_ranks * self.chunk_factor)
                for ch in range(self.chunk_factor):
                    input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch)
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
                rank_buffers.append(buffers)
        return rank_buffers

    # Expected output buffer for allgather
    def check(self, prog):
        correct = True
        buf = Buffer.output
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for i in range(self.num_ranks):
                for ch in range(self.chunk_factor):
                    index = i * self.chunk_factor + ch
                    chunk = output[index]
                    if chunk is None:
                        print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None")
                        correct = False
                    elif chunk.origin_rank != i or chunk.origin_index != ch:
                        print(
                            f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
                        )
                        correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # For inplace AllGathers, the input buffer points into the output buffer
        if self.inplace and buffer == Buffer.input:
            return Buffer.output, index + rank * self.chunk_factor
        else:
            return buffer, index


class AllReduce(Collective):

    def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
        num_chunk_groups = kwargs.get("num_chunk_groups", num_ranks)
        Collective.__init__(
            self, num_ranks, chunk_factor, inplace, num_ranks_per_node, num_chunk_groups=num_chunk_groups
        )
        self.name = "allreduce"

    def init_buffers(self):
        chunks_per_node = self.chunk_factor
        rank_buffers = []
        for r in range(self.num_ranks):
            input_buffer = []
            output_buffer = [None] * chunks_per_node
            for c in range(chunks_per_node):
                # Chunks start at rank r index c, and end on all ranks (-1) at index c
                input_buffer.append(Chunk(r, c, -1, c))
            # For inplace, the input and output buffers are the same.
            if self.inplace:
                buffers = {Buffer.input: input_buffer, Buffer.output: input_buffer}
            else:
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
            rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        chunks_per_node = self.chunk_factor
        expected_chunks = []
        buf = Buffer.input if self.inplace else Buffer.output

        for c in range(chunks_per_node):
            chunk = ReduceChunk(-1, [])
            for r in range(self.num_ranks):
                chunk = chunk.reduce(-1, Chunk(r, c))
            expected_chunks.append(chunk)

        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for c in range(chunks_per_node):
                chunk = output[c]
                if chunk is None or chunk != expected_chunks[c]:
                    print(
                        f"Rank {r} chunk {c} is incorrect should be ReduceChunk index {c} from all ranks, given {chunk}"
                    )
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index
        else:
            return buffer, index


class ReduceScatter(Collective):
    def __init__(self, num_ranks, chunk_factor, inplace):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "reducescatter"

    def init_buffers(self):
        rank_buffers = []
        for r in range(self.num_ranks):
            if self.inplace:
                input_buffer = []
                for i in range(self.num_ranks):
                    for c in range(self.chunk_factor):
                        input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
                buffers = {Buffer.input: input_buffer}
                rank_buffers.append(buffers)
            else:
                input_buffer = []
                output_buffer = [None] * self.chunk_factor
                for i in range(self.num_ranks):
                    for c in range(self.chunk_factor):
                        input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
                rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        expected_chunks = []
        buf = Buffer.input if self.inplace else Buffer.output
        for c in range(self.num_ranks * self.chunk_factor):
            chunk = ReduceChunk(-1, [])
            for r in range(self.num_ranks):
                chunk = chunk.reduce(-1, Chunk(r, c))
            expected_chunks.append(chunk)

        correct = True
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for c in range(self.chunk_factor):
                correct_idx = r * self.chunk_factor + c
                if self.inplace:
                    c = correct_idx
                chunk = output[c]
                if chunk is None or chunk != expected_chunks[correct_idx]:
                    print(f"Rank {r} chunk {c} is incorrect should be index {correct_idx} from all ranks given {chunk}")
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        # For inplace ReduceScatter, the output buffer is a pointer into the input buffer
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index + rank * self.chunk_factor
        else:
            return buffer, index


# SendRecv is a collective that sends a chunk from one rank to another.
# It is used to test the correctness of the send and receive instructions.
class SendRecv(Collective):
    def __init__(self, num_ranks, chunk_factor, inplace):
        assert num_ranks == 2, "SendRecv only supports 2 ranks"
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "sendrecv"

    def init_buffers(self):
        rank_buffers = []
        for r in range(self.num_ranks):
            input_buffer = [None] * self.chunk_factor
            output_buffer = [None] * self.chunk_factor
            for c in range(self.chunk_factor):
                input_buffer[c] = Chunk(r, c, -1, c)
            buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
            rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        correct = True
        buff_type = Buffer.input if self.inplace else Buffer.output
        for r in range(self.num_ranks):
            output = prog.buffers[r][buff_type]
            for c in range(self.chunk_factor):
                chunk = output[c]
                if chunk is None or chunk.origin_rank != 1 - r or chunk.origin_index != c:
                    print(f"Rank {r} chunk {c} is incorrect should be ({1 - r}, {c}) given {chunk}")
                    correct = False

        return correct

    def get_buffer_index(self, rank, buffer, index):
        if self.inplace and buffer == Buffer.output:
            return Buffer.input, index
        return buffer, index


class Broadcast(Collective):
    def __init__(self, num_ranks, chunk_factor, inplace, root):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "broadcast"
        self.root = root

    # Initializes input buffer for a broadcast
    def init_buffers(self):
        rank_buffers = []
        if self.inplace:
            # Inplace broadcast only uses the input buffer
            for r in range(self.num_ranks):
                input_buffer = [None] * (self.chunk_factor)
                for ch in range(self.chunk_factor):
                    input_buffer[ch] = Chunk(self.root, ch, -1, ch)
                buffers = {
                    Buffer.input: input_buffer,
                    Buffer.output: input_buffer,
                }
                rank_buffers.append(buffers)
        else:
            for r in range(self.num_ranks):
                input_buffer = [None] * self.chunk_factor
                output_buffer = [None] * self.chunk_factor
                if r == self.root:
                    for ch in range(self.chunk_factor):
                        input_buffer[ch] = Chunk(self.root, ch, -1, ch)
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
                rank_buffers.append(buffers)
        return rank_buffers

    # Expected output buffer for broadcast
    def check(self, prog):
        correct = True
        buf = Buffer.output
        for r in range(self.num_ranks):
            output = prog.buffers[r][buf]
            for ch in range(self.chunk_factor):
                index = ch
                chunk = output[index]
                if chunk is None:
                    print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given None")
                    correct = False
                elif chunk.origin_rank != self.root or chunk.origin_index != ch:
                    print(
                        f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})"
                    )
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        return buffer, index
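A quick sketch of what init_buffers produces for a non-inplace AllGather with two ranks and one chunk per rank; the values follow directly from the code above:

ag = AllGather(2, 1, False)
bufs = ag.init_buffers()
assert bufs[0][Buffer.input] == [Chunk(0, 0, -1, 0)]   # rank 0's own chunk
assert bufs[1][Buffer.input] == [Chunk(1, 0, -1, 1)]
assert bufs[0][Buffer.output] == [None, None]          # filled in as the program runs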
python/mscclpp/language/dag/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.dag.lower import DagLower
from mscclpp.language.dag.optimizer import DagOptimizer
python/mscclpp/language/dag/instruction_dag.py (new file, 373 lines)
@@ -0,0 +1,373 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


from collections import defaultdict
from mscclpp.language.buffer import Buffer
from mscclpp.language.types import (
    Channel,
    ChannelType,
    ChunkRef,
    Instruction,
    Op,
)


class InstructionDAG:
    def __init__(self, num_ranks: int, buffers: list):
        self.num_ranks = num_ranks
        self.buffers = buffers
        # State for the actual instruction DAG
        self.operations = {}  # slot -> operations
        self.last_writer = {}  # slot -> last writing op
        self.last_readers = defaultdict(list)  # slot -> list of last reading ops
        # State for the MSCCLPP-IR
        self.tbs = []
        for _ in range(num_ranks):
            self.tbs.append({})
        self.tb_mapping = {}
        self.num_channels = [1] * num_ranks
        self.tb_steps = [{} for _ in range(num_ranks)]

    def convert_set_list(self):
        ops = []
        visited = set()
        for slot, op in self.operations.items():
            if op.inst == Instruction.start:
                op.next = list(op.next)
                for o in op.next:
                    ops.append(o)
            elif op.inst != Instruction.copy:
                ops.append(op)

        while len(ops) > 0:
            op = ops[0]
            if op not in visited:
                visited.add(op)
                op.next = list(op.next)
                ops = ops[1:] + op.next
            else:
                ops = ops[1:]
        return visited

    def complete_channels(self):
        send_op = [Instruction.put, Instruction.signal, Instruction.put_packet]
        recv_op = [Instruction.wait, Instruction.get, Instruction.read_reduce_copy]
        group_send_op = [Instruction.group_store]
        group_recv_op = [Instruction.group_load_reduce]
        for rank, rank_tbs in enumerate(self.tbs):
            for tbid, tb in rank_tbs.items():
                chans = set()
                for op in tb.ops:
                    if op.inst == Instruction.barrier:
                        continue
                    if op.src is not None:
                        src_buffer = (
                            Buffer.scratch
                            if op.src.buffer is not Buffer.input and op.src.buffer is not Buffer.output
                            else op.src.buffer
                        )
                    if op.dst is not None:
                        dst_buffer = (
                            Buffer.scratch
                            if op.dst.buffer is not Buffer.input and op.dst.buffer is not Buffer.output
                            else op.dst.buffer
                        )
                    if op.channel_type == ChannelType.nvls:
                        if op.inst in group_send_op:
                            ranks = [dst[0].rank for dst in op.dsts]
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
                            chans.add(chan)
                        elif op.inst in group_recv_op:
                            ranks = [src[0].rank for src in op.srcs]
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, ranks)
                            chans.add(chan)
                    else:
                        if op.inst in send_op:
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, op.dst.rank)
                            chans.add(chan)
                        elif op.inst in recv_op:
                            chan = Channel(src_buffer, dst_buffer, op.channel_type, op.src.rank)
                            chans.add(chan)
                tb.channels = list(chans)

    # InstructionDAG - builds the roots of the DAG
    def add_start(self, rank, buffer, index, ref):
        slot = (rank, buffer, index)
        op = Op(Instruction.start, rank, ref, ref, next=set(), prev=set())
        self.operations[slot] = op
        self.last_writer[slot] = op

    # InstructionDAG - adds a copy node
    def add_copy(self, rank, send_ref, recv_ref, tb, trans_from_packet=False, trans_to_packet=False):
        tb_step = self._get_tb_step(rank, tb)
        if trans_from_packet:
            op = Op(Instruction.copy_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
        elif trans_to_packet:
            op = Op(
                Instruction.transform_to_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step
            )
        else:
            op = Op(Instruction.copy, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
        dstbuffer = recv_ref.buffer
        dstindex = recv_ref.index
        srcbuffer = send_ref.buffer
        srcindex = send_ref.index
        size = recv_ref.size
        # Sending part of copy [Read]
        self._read(rank, srcbuffer, srcindex, size, op)
        # Receiving part of copy [Write]
        self._write(rank, dstbuffer, dstindex, size, op)
        return op

    # InstructionDAG - adds a reduce node
    def add_reduce(self, rank, send_ref, recv_ref, tb, use_packet=False):
        tb_step = self._get_tb_step(rank, tb)
        if use_packet:
            op = Op(Instruction.reduce_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
        else:
            op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, step=tb_step)
        dstbuffer = recv_ref.buffer
        dstindex = recv_ref.index
        srcbuffer = send_ref.buffer
        srcindex = send_ref.index
        size = recv_ref.size
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        # Sending part of reduce [Read]
        self._read(rank, srcbuffer, srcindex, size, op)
        # Receiving part of reduce [Write]
        self._write(rank, dstbuffer, dstindex, size, op, read=True)
        return op

    # InstructionDAG - adds a put node
    def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet=False):
        tb_step = self._get_tb_step(rank, tb)
        if use_packet:
            op = Op(
                Instruction.put_packet,
                rank,
                send_ref,
                recv_ref,
                next=set(),
                prev=set(),
                tb=tb,
                channel_type=ch_type,
                step=tb_step,
            )
        else:
            op = Op(
                Instruction.put,
                rank,
                send_ref,
                recv_ref,
                next=set(),
                prev=set(),
                tb=tb,
                channel_type=ch_type,
                step=tb_step,
            )
        buffer = send_ref.buffer
        index = send_ref.index
        size = send_ref.size
        self._read(rank, buffer, index, size, op)
        op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        return op

    def add_get(self, rank, send_ref, recv_ref, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.get, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step
        )
        buffer = recv_ref.buffer
        index = recv_ref.index
        size = recv_ref.size
        self._write(rank, buffer, index, size, op)
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        return op

    # InstructionDAG - adds a signal node.
    def add_signal(self, rank, send_ref, recv_ref, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.signal,
            rank,
            send_ref,
            recv_ref,
            next=set(),
            prev=set(),
            tb=tb,
            channel_type=ch_type,
            step=tb_step,
        )
        buffer = send_ref.buffer
        index = send_ref.index
        size = send_ref.size
        # Treat signal as a write: it acts as a barrier that prevents the
        # following instructions from being scheduled above it.
        self._write(rank, buffer, index, size, op)
        op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        return op

    def add_flush(self, rank, send_ref, recv_ref, tb):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.flush,
            rank,
            send_ref,
            recv_ref,
            next=set(),
            prev=set(),
            tb=tb,
            channel_type=ChannelType.proxy,
            step=tb_step,
        )
        buffer = send_ref.buffer
        index = send_ref.index
        size = send_ref.size
        self._read(rank, buffer, index, size, op)
        op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        return op

    def add_wait(self, rank, dst_ref, src_ref, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.wait, rank, src_ref, dst_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step
        )
        buffer = dst_ref.buffer
        index = dst_ref.index
        size = dst_ref.size
        self._write(rank, buffer, index, size, op)
        op.srcs.append((ChunkRef(src_ref.rank, src_ref.buffer, src_ref.index, src_ref.size), tb_step))
        op.dsts.append((ChunkRef(dst_ref.rank, dst_ref.buffer, dst_ref.index, dst_ref.size), tb_step))
        return op

    def add_read_reduce(self, rank, send_ref, recv_ref, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.read_reduce_copy,
            rank,
            send_ref,
            recv_ref,
            next=set(),
            prev=set(),
            tb=tb,
            channel_type=ch_type,
            step=tb_step,
        )
        buffer = recv_ref.buffer
        index = recv_ref.index
        size = recv_ref.size
        op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        self._write(rank, buffer, index, size, op, read=True)
        return op

    def add_barrier(self, rank, tb_list, barrier_id):
        buffers = self.buffers[rank]
        for tb in tb_list:
            tb_step = self._get_tb_step(rank, tb)
            extra = {"tb_list": tb_list, "barrier_id": barrier_id}
            op = Op(Instruction.barrier, rank, None, None, next=set(), prev=set(), tb=tb, step=tb_step, extra=extra)
            for buffer_type, buffer in buffers.items():
                self._write(rank, buffer_type, 0, len(buffer), op)

    def add_group_load_reduce(self, rank, send_refs, recv_ref, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.group_load_reduce,
            rank,
            recv_ref,
            recv_ref,
            next=set(),
            prev=set(),
            tb=tb,
            channel_type=ch_type,
            step=tb_step,
        )
        # treat recv_ref as src for group_load_reduce
        op.srcs.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        for send_ref in send_refs:
            op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        buffer = recv_ref.buffer
        index = recv_ref.index
        size = recv_ref.size
        self._write(rank, buffer, index, size, op, read=True)

    def add_group_store(self, rank, send_ref, recv_refs, tb, ch_type):
        tb_step = self._get_tb_step(rank, tb)
        op = Op(
            Instruction.group_store,
            rank,
            send_ref,
            send_ref,
            next=set(),
            prev=set(),
            tb=tb,
            channel_type=ch_type,
            step=tb_step,
        )
        # treat send_ref as dst for group_store
        op.dsts.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
        for recv_ref in recv_refs:
            op.dsts.append((ChunkRef(recv_ref.rank, recv_ref.buffer, recv_ref.index, recv_ref.size), tb_step))
        buffer = send_ref.buffer
        index = send_ref.index
        size = send_ref.size
        self._read(rank, buffer, index, size, op)
        return op

    def _get_tb_step(self, rank: int, tb: int):
        if tb in self.tb_steps[rank]:
            self.tb_steps[rank][tb] += 1
            return self.tb_steps[rank][tb]
        else:
            self.tb_steps[rank][tb] = 0
            return 0

    # InstructionDAG helper - identifies the dependencies for a write-type operation (recv, copy, rrc, reduce)
    def _write(self, rank, buffer, index, size, op, read=False):
        prev_ops = set()
        for i in range(index, index + size):
            slot = (rank, buffer, i)
            if read:
                assert slot in self.last_writer, f"Destination slot has never been written before a reduce {op}"

            # First write to this slot
            if slot not in self.operations:
                self.operations[slot] = op

            # If there are active readers - these are the previous operations
            # Else the previous operation is the last write (if there is one)
            readers = self.last_readers[slot]
            if len(readers) > 0:
                prev_ops.update(readers)
            elif slot in self.last_writer:
                prev_ops.add(self.last_writer[slot])

            # Set the last_writer to this op, and clear all readers
            self.last_writer[slot] = op
            self.last_readers[slot] = []

        # Update the next pointer of the previous ops
        for prev_op in prev_ops:
            prev_op.next.add(op)
            op.prev.add(prev_op)

    # InstructionDAG helper - identifies the dependencies for read-type operations (send, copy, reduce)
    def _read(self, rank, buffer, index, size, op):
        prev_ops = set()
        for i in range(index, index + size):
            slot = (rank, buffer, i)
            assert slot in self.last_writer, f"Slot has never been written before a read-type {op}"
            # The previous operation for a reader is the last write to the slot
            writer = self.last_writer[slot]
            prev_ops.add(writer)
            self.last_readers[slot].append(op)

        # Update the next pointer of the previous ops
        for prev_op in prev_ops:
            prev_op.next.add(op)
            op.prev.add(prev_op)
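A toy model of the dependency rule implemented by _read/_write above (not the real classes): reads depend on the last writer of a slot, and a write depends on all active readers, or on the last writer if there are none.

from collections import defaultdict

last_writer, last_readers, edges = {}, defaultdict(list), []

def write(slot, op):
    # a write waits on all active readers, else on the previous writer
    prev = list(last_readers[slot]) or ([last_writer[slot]] if slot in last_writer else [])
    edges.extend((p, op) for p in prev)
    last_writer[slot], last_readers[slot] = op, []

def read(slot, op):
    # a read waits on the last writer and becomes an active reader
    edges.append((last_writer[slot], op))
    last_readers[slot].append(op)

write("s", "W1"); read("s", "R1"); read("s", "R2"); write("s", "W2")
assert ("W1", "R1") in edges and ("R1", "W2") in edges and ("R2", "W2") in edges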
python/mscclpp/language/dag/lower.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import copy
from typing import List
from mscclpp.language.buffer import Buffer
from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.types import ChunkRef, Gpu, Instruction, Op, ReplicationPolicy, Threadblock


class DagLower:
    def __init__(self, dag: InstructionDAG):
        self.dag = dag
        self.instanced_tbs = []

    def lower(self, instances: int, replication_policy: ReplicationPolicy):
        self._infer_dependencies()
        self._lower_buffers(instances)
        self._replicate(instances, replication_policy)
        return self._lower_tbs()

    def _replicate(self, instances: int, replication_policy: ReplicationPolicy):
        # update op step
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                for id, op in enumerate(tb.ops):
                    op.step = id

        if instances == 1:
            self.instanced_tbs = self.dag.tbs
            return

        self.instanced_tbs = []
        for _ in range(self.dag.num_ranks):
            self.instanced_tbs.append({})

        def get_new_index(rank, buffer, index, size, i):
            if replication_policy == ReplicationPolicy.interleaved:
                return index * instances + i * size
            return len(self.dag.buffers[rank][buffer]) * i + index

        def get_instance_ref(ref):
            if ref is None:
                return None
            iindex = get_new_index(ref.rank, ref.buffer, ref.index, ref.size, i)
            iref = ChunkRef(ref.rank, ref.buffer, iindex, ref.size)
            return iref

        def update_extra(op, ori_op):
            if op.inst == Instruction.barrier:
                tb_list = ori_op.extra["tb_list"]
                new_tb_list = [tb * instances + i for tb in tb_list]
                op.extra["tb_list"] = new_tb_list
                op.extra["barrier_id"] = ori_op.extra["barrier_id"] * instances + i

        for i in range(instances):
            # Generate all the threadblocks and ops
            for rank, rank_tbs in enumerate(self.dag.tbs):
                # rank_channels = self.num_channels[rank]
                for tbid, tb in rank_tbs.items():
                    itbid = tbid * instances + i
                    itb = Threadblock(id=itbid)
                    itb.ops = [None] * len(tb.ops)
                    for s, op in enumerate(tb.ops):
                        isrc = get_instance_ref(op.src)
                        idst = get_instance_ref(op.dst)
                        idepends = []
                        # Note: We don't need to fill out the rest of the metadata since replication is the last optimization
                        iop = Op(
                            op.inst,
                            op.rank,
                            isrc,
                            idst,
                            idepends,
                            op.step,
                            itbid,
                            channel_type=op.channel_type,
                            extra=copy.deepcopy(op.extra),
                        )
                        update_extra(iop, op)
                        itb.ops[s] = iop
                        for src, step in op.srcs:
                            isrc = get_instance_ref(src)
                            iop.srcs.append((isrc, step))
                        for dst, step in op.dsts:
                            idst = get_instance_ref(dst)
                            iop.dsts.append((idst, step))
                    for chan in tb.channels:
                        itb.channels.append(chan)
                    self.instanced_tbs[op.rank][itbid] = itb

        # Redo dependency analysis
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                for i in range(instances):
                    itbid = tbid * instances + i
                    itb = self.instanced_tbs[rank][itbid]
                    for op, iop in zip(tb.ops, itb.ops):
                        iop.depends = [None] * len(op.depends)
                        for s, dep in enumerate(op.depends):
                            dep_tbid = dep.tb
                            dep_itbid = dep_tbid * instances + i
                            dep_step = dep.step
                            iop.depends[s] = self.instanced_tbs[op.rank][dep_itbid].ops[dep_step]

    # Convert local scratch buffers to index into one global scratch buffer
    def _lower_chunk(self, chunk):
        if chunk is not None and chunk.buffer is not Buffer.input and chunk.buffer is not Buffer.output:
            buffer = self.dag.buffers[chunk.rank][chunk.buffer].get_buffer()
            index = self.dag.buffers[chunk.rank][chunk.buffer].get_global_index(chunk.index)
            return ChunkRef(chunk.rank, buffer, index, chunk.size)
        return chunk

    # Assigns each scratch buffer an offset into the global scratch buffer
    def _lower_buffers(self, instances):
        for rank_buffers in self.dag.buffers:
            offset = 0
            for key, buf in rank_buffers.items():
                if key is not Buffer.input and key is not Buffer.output:
                    buf.set_offset(offset)
                    offset += buf.instance_size() * instances

    def _lower_tbs(self) -> List[Gpu]:
        gpus = []
        for rank, rank_tbs in enumerate(self.instanced_tbs):
            lowered_tbs = {}
            for tbid, tb in rank_tbs.items():
                for op in tb.ops:
                    op.src = self._lower_chunk(op.src)
                    op.dst = self._lower_chunk(op.dst)
                    srcs = sorted(op.srcs, key=lambda x: x[1])
                    dsts = sorted(op.dsts, key=lambda x: x[1])
                    op.srcs = [self._lower_chunk(src[0]) for src in srcs]
                    op.dsts = [self._lower_chunk(dst[0]) for dst in dsts]
                lowered_tbs[tbid] = tb
            gpus.append(Gpu(rank, list(lowered_tbs.values())))
        return gpus

    def _infer_dependencies(self):
        visited = set()
        for _, op in self.dag.operations.items():
            if op in visited:
                continue
            frontier = [op]
            while len(frontier) > 0:
                op = frontier[0]
                if op in visited:
                    frontier = frontier[1:]
                    continue
                # The dependencies for every op are the ops stored in prev.
                # Filter out dependencies that are satisfied by tbs executing ops sequentially.
                # If there are multiple dependent ops from the same tb, keep the one that happens last.
                depends = {}
                for dep_op in list(op.prev):
                    if dep_op.inst != Instruction.start:
                        tb = dep_op.tb
                        if tb not in depends or dep_op.step > depends[tb].step:
                            depends[tb] = dep_op
                op.depends = list(depends.values())
                visited.add(op)
                frontier = frontier[1:] + op.next
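A worked example of get_new_index above for instances = 2 and chunks of size 1: the interleaved policy maps chunk k of instance i to k * instances + i * size, so the instances' chunks interleave chunk by chunk, while the default policy offsets each instance by a whole buffer length instead.

instances, size = 2, 1
interleaved = lambda index, i: index * instances + i * size
assert [interleaved(k, i) for k in range(2) for i in range(2)] == [0, 1, 2, 3]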
python/mscclpp/language/dag/optimizer.py (new file, 405 lines)
@@ -0,0 +1,405 @@
|
||||
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from mscclpp.language.utils import (
    buf_dst_src_match,
    circular_dep_after_merge,
    merge_op,
    remove_op,
    same_chan_type,
    same_count,
    same_buf_dst,
    same_buf_src,
    same_src_dst_buffer_type,
    same_tb,
    all_prevs_visited_after_merge,
)
from mscclpp.language.dag.instruction_dag import InstructionDAG
from mscclpp.language.types import ChunkRef, ChannelType, Instruction, Op, Threadblock


class _InstructionOptimizer:

    def try_merge_same_instructions(
        self,
        op: Op,
        next_op: Op,
        tb: Threadblock,
        queue: list,
        expected_next_inst: Instruction,
        same_buf_func: callable,
    ) -> bool:
        """
        Attempts to merge two instructions if conditions are met.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :param expected_next_inst: The instruction type expected for the next operation.
        :param same_buf_func: The function to check if the buffer is the same (same_buf_dst or same_buf_src).
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == expected_next_inst
            and same_tb(op, next_op)
            and same_buf_func(op, next_op)
            and same_count(op, next_op)
            and same_chan_type(op, next_op)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            # Append the source chunks from next_op
            op.srcs.append(
                (
                    ChunkRef(next_op.src.rank, next_op.src.buffer, next_op.src.index, next_op.src.size),
                    next_op.step,
                )
            )
            # For 'signal' and 'wait' instructions, append destination chunks too
            if expected_next_inst in [Instruction.signal, Instruction.wait, Instruction.flush]:
                op.dsts.append(
                    (
                        ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),
                        next_op.step,
                    )
                )
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_compact_instructions(
        self, op: Op, tb: Threadblock, queue: list, inst_type: Instruction, same_src_dst_func: callable
    ) -> bool:
        """
        Tries to compact instructions of the same instruction type: multiple
        instructions of the same type are compacted into a single instruction.
        :param op: The current operation.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :param inst_type: The type of the instruction being processed (get, put, put_packet, ...).
        :param same_src_dst_func: The function to check that source and destination buffer types match.
        :return: True if operations are merged, False otherwise.
        """
        if len(queue) > 1:
            seq_op = queue[1]
            if (
                seq_op.inst == inst_type
                and same_src_dst_func(op, seq_op)
                and same_chan_type(op, seq_op)
                and same_count(op, seq_op)
                and not circular_dep_after_merge(op, seq_op)
                and all_prevs_visited_after_merge(op, seq_op)
            ):
                # Append the source and destination chunks from seq_op
                op.dsts.append(
                    (
                        ChunkRef(seq_op.dst.rank, seq_op.dst.buffer, seq_op.dst.index, seq_op.dst.size),
                        seq_op.step,
                    )
                )
                op.srcs.append(
                    (
                        ChunkRef(seq_op.src.rank, seq_op.src.buffer, seq_op.src.index, seq_op.src.size),
                        seq_op.step,
                    )
                )
                merge_op(op, seq_op)
                tb.ops.remove(seq_op)
                queue.remove(seq_op)
                return True
        return False

    def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool:
        """
        Attempts to fuse 'put' operations with other operations like read_reduce_copy, reduce, etc.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :return: True if operations are merged, False otherwise.
        """
        if (
            (next_op.inst == Instruction.put or next_op.inst == Instruction.put_packet)
            and same_tb(op, next_op)
            and same_count(op, next_op)
            and buf_dst_src_match(op, next_op)
            and next_op.channel_type == ChannelType.sm
            and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
                return False
            # Adjust instruction type and channel if needed
            if op.inst == Instruction.read_reduce_copy:
                op.inst = Instruction.read_reduce_copy_send
            elif op.inst == Instruction.reduce:
                op.inst = Instruction.reduce_send
                op.channel_type = ChannelType.sm
            elif op.inst == Instruction.reduce_packet:
                op.inst = Instruction.reduce_send_packet
                op.channel_type = ChannelType.sm
            # Append the destination chunk from next_op
            op.dsts.append(
                (
                    ChunkRef(next_op.dst.rank, next_op.dst.buffer, next_op.dst.index, next_op.dst.size),
                    next_op.step,
                )
            )
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_fuse_instructions_using_proxy_channel(
        self, op: Op, next_op: Op, tb: Threadblock, queue: list, expected_next_inst: Instruction
    ) -> bool:
        """
        Attempts to fuse operations that use a proxy channel.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :param expected_next_inst: The instruction type expected for the next operation.
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == expected_next_inst
            and same_tb(op, next_op)
            and same_count(op, next_op)
            and same_buf_dst(op, next_op)
            and same_buf_src(op, next_op)
            and same_chan_type(op, next_op)
            and op.channel_type == ChannelType.proxy
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            if op.inst == Instruction.put and next_op.inst == Instruction.signal:
                op.inst = Instruction.put_with_signal
            elif op.inst == Instruction.put_with_signal and next_op.inst == Instruction.flush:
                op.inst = Instruction.put_with_signal_and_flush
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_fuse_with_group_store(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -> bool:
        """
        Attempts to fuse 'group_load_reduce' operations with 'group_store' operations.
        :param op: The current operation.
        :param next_op: The next operation to potentially merge with.
        :param tb: The thread block containing the operations.
        :param queue: The queue of operations.
        :return: True if operations are merged, False otherwise.
        """
        if (
            next_op.inst == Instruction.group_store
            and same_count(op, next_op)
            and buf_dst_src_match(op, next_op)
            and same_chan_type(op, next_op)
            and not circular_dep_after_merge(op, next_op)
            and all_prevs_visited_after_merge(op, next_op)
        ):
            # Append the destination chunks from next_op
            op.inst = Instruction.group_load_reduce_store
            op.src = next_op.src
            for dst in next_op.dsts:
                op.dsts.append(dst)
            # Merge operations
            merge_op(op, next_op)
            tb.ops.remove(next_op)
            queue.remove(next_op)
            return True
        return False

    def try_remove_op(self, pending_remove_op: Op, condition: bool) -> bool:
        if condition:
            remove_op(pending_remove_op)
            return True
        return False

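# Schematic summary of the fusions applied below (illustrative only; see the methods
# above for the exact guard conditions):
#   put ; signal                    -> put_with_signal            (proxy channel)
#   put_with_signal ; flush         -> put_with_signal_and_flush  (proxy channel)
#   read_reduce_copy ; put          -> read_reduce_copy_send      (sm channel)
#   reduce ; put                    -> reduce_send                (sm channel)
#   group_load_reduce ; group_store -> group_load_reduce_store    (nvls channel)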
class DagOptimizer:

    def __init__(self, instruction_dag: InstructionDAG):
        self.optimizer = _InstructionOptimizer()
        self.dag = instruction_dag

    def remove_redundant_signal_wait(self):
        # For packet ops, we can remove signal/wait
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst == Instruction.put_packet:
                        for next_op in op.next:
                            fused = self.optimizer.try_remove_op(next_op, next_op.inst == Instruction.signal)
                            if fused:
                                break
                    elif op.inst == Instruction.reduce_packet or op.inst == Instruction.copy_packet:
                        for prev_op in op.prev:
                            fused = self.optimizer.try_remove_op(prev_op, prev_op.inst == Instruction.wait)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    def fuse_instructions(self):
        self._fuse_instructions_using_proxy_channel()
        self._fuse_same_instructions()
        self._optimize_rrcs_rs()
        self._optimize_group_ops()
        self._compact_instructions()

    # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di)
    # -> putWithSignal(src, sbuf, si, dst, dbuf, di)
    # put(src, sbuf, si, dst, dbuf, di) signal(src, sbuf, si, dst, dbuf, di) flush(src, sbuf, si, dst, dbuf, di)
    # -> putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    def _fuse_instructions_using_proxy_channel(self):
        inst_followup_map = {
            Instruction.put: Instruction.signal,
            Instruction.put_with_signal: Instruction.flush,
        }
        for rank, rank_tbs in enumerate(self.dag.tbs):
            for tbid, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_followup_map:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_instructions_using_proxy_channel(
                                op, next_op, tb, queue, inst_followup_map[op.inst]
                            )
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # rrc(_,_,_,dst,dbuf,di) rrc(_,_,_,dst,dbuf,di) -> rrc(list[src,sbuf,si], dst, dbuf, di)
    # signal(_,_,_,dst,dbuf,di) signal(_,_,_,dst,dbuf,di) -> signal(_,_,_,list[dst,dbuf,di])
    # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_)
    # reduce(_,_,_,dst,dbuf,di) reduce(_,_,_,dst,dbuf,di) -> reduce(list[src,sbuf,si], dst, dbuf, di)
    # reduce_packet(_,_,_,dst,dbuf,di) reduce_packet(_,_,_,dst,dbuf,di) -> reduce_packet(list[src,sbuf,si], dst, dbuf, di)
    def _fuse_same_instructions(self):
        # Mapping from each instruction to its same-buffer check function
        instruction_handlers = {
            Instruction.read_reduce_copy: same_buf_dst,
            Instruction.reduce: same_buf_dst,
            Instruction.reduce_packet: same_buf_dst,
            Instruction.signal: same_buf_src,
            Instruction.wait: same_buf_dst,
        }

        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    inst_type = op.inst
                    if inst_type in instruction_handlers:
                        for next_op in op.next:
                            same_buf_func = instruction_handlers[inst_type]
                            if self.optimizer.try_merge_same_instructions(
                                op, next_op, tb, queue, inst_type, same_buf_func
                            ):
                                fused = True
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # rrc(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rrcs(_,_,_,_,_,_)
    # reduce(_,_,_,dst,dbuf,di) put(dst,dbuf,di,_,_,_) -> rs(_,_,_,_,_,_)
    def _optimize_rrcs_rs(self):
        inst_types = [
            Instruction.read_reduce_copy,
            Instruction.reduce,
            Instruction.reduce_packet,
            Instruction.read_reduce_copy_send,
            Instruction.reduce_send,
            Instruction.reduce_send_packet,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_types:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_with_put(op, next_op, tb, queue)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # glre(srcs, sbuf, si, _, _, _), gstore (_, _, _, dsts, dbuf, di) -> glres(srcs, sbuf, si, dsts, dbuf, di)
    def _optimize_group_ops(self):
        inst_types = [
            Instruction.group_load_reduce,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in inst_types:
                        for next_op in op.next:
                            fused = self.optimizer.try_fuse_with_group_store(op, next_op, tb, queue)
                            if fused:
                                break
                    if fused:
                        continue
                    queue = queue[1:]

    # Merge ops that are independent of other operations and have no other operations in between.
    # get(src, sbuf, si, dst, dbuf, di) get(src, sbuf, si, dst, dbuf, di) -> get(list[src,sbuf,si], list[dst,dbuf,di])
    # put(src, sbuf, si, dst, dbuf, di) put(src, sbuf, si, dst, dbuf, di) -> put(list[src,sbuf,si], list[dst,dbuf,di])
    # putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    #   putWithSignal/putWithSignalAndFlush(src, sbuf, si, dst, dbuf, di)
    #   -> putWithSignal/putWithSignalAndFlush(list[src,sbuf,si], list[dst,dbuf,di])
    # wait(src,sbuf,si,_,_,_) wait(src,sbuf,si,_,_,_) -> wait(list[src,sbuf,si],_,_,_)
    def _compact_instructions(self):
        compactable_inst = [
            Instruction.get,
            Instruction.put,
            Instruction.put_packet,
            Instruction.put_with_signal,
            Instruction.put_with_signal_and_flush,
            Instruction.signal,
            Instruction.flush,
            Instruction.wait,
        ]
        for _, rank_tbs in enumerate(self.dag.tbs):
            for _, tb in rank_tbs.items():
                if tb.id == -1:
                    continue
                queue = list(tb.ops)
                while len(queue) > 0:
                    op = queue[0]
                    fused = False
                    if op.inst in compactable_inst:
                        fused = self.optimizer.try_compact_instructions(
                            op, tb, queue, op.inst, same_src_dst_buffer_type
                        )

                    if fused:
                        continue
                    queue = queue[1:]
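The optimizer is driven in two passes. A minimal usage sketch, mirroring how MSCCLPPProgram.lower() invokes it later in this change (instr_dag is assumed to be an already-built InstructionDAG):

# Sketch: expected driving order for DagOptimizer.
optimizer = DagOptimizer(instr_dag)
optimizer.remove_redundant_signal_wait()  # drop signal/wait made redundant by packet ops
optimizer.fuse_instructions()             # proxy fusion, merging, rrcs/rs, group ops, compaction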
534
python/mscclpp/language/ir.py
Normal file
@@ -0,0 +1,534 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass
import json
from typing import Dict, List, Optional, Union

from mscclpp.language.types import Buffer, ChannelType, Op, Program, Instruction

_local_src_insts_mscclpp: set = {
    Instruction.put,
    Instruction.put_packet,
    Instruction.signal,
    Instruction.flush,
    Instruction.put_with_signal,
    Instruction.put_with_signal_and_flush,
    Instruction.copy,
    Instruction.copy_packet,
    Instruction.transform_to_packet,
    Instruction.reduce,
    Instruction.reduce_packet,
    Instruction.reduce_send,
    Instruction.reduce_send_packet,
    Instruction.group_load_reduce_store,
    Instruction.group_store,
}
_local_dst_insts_mscclpp: set = {
    Instruction.get,
    Instruction.wait,
    Instruction.read_reduce_copy,
    Instruction.copy,
    Instruction.copy_packet,
    Instruction.transform_to_packet,
    Instruction.reduce,
    Instruction.read_reduce_copy_send,
    Instruction.reduce_send,
    Instruction.reduce_packet,
    Instruction.reduce_send_packet,
    Instruction.group_load_reduce_store,
    Instruction.group_load_reduce,
}

_insts_no_need_sync_barrier: set = {
    Instruction.copy_packet,
    Instruction.reduce_packet,
    Instruction.reduce_send_packet,
    Instruction.barrier,
}


def ir_to_json(program: Program):
    # Figure out sizes of buffers based on usage
    buffer_sizes = defaultdict(lambda: 0)
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for op in tb.ops:
                if op.inst in _local_src_insts_mscclpp:
                    key = (gpu.rank, op.src.buffer)
                    buffer_sizes[key] = max(buffer_sizes[key], op.src.index + op.src.size)
                    for src in op.srcs:
                        key = (gpu.rank, src.buffer)
                        buffer_sizes[key] = max(buffer_sizes[key], src.index + src.size)
                if op.inst in _local_dst_insts_mscclpp:
                    key = (gpu.rank, op.dst.buffer)
                    buffer_sizes[key] = max(buffer_sizes[key], op.dst.index + op.dst.size)
                    # ignore remote buffers
                    if (
                        op.inst != Instruction.read_reduce_copy_send
                        and op.inst != Instruction.reduce_send
                        and op.inst != Instruction.reduce_send_packet
                    ):
                        for dst in op.dsts:
                            key = (gpu.rank, dst.buffer)
                            buffer_sizes[key] = max(buffer_sizes[key], dst.index + dst.size)
    for gpu in program.gpus:
        gpu.input_chunks = max(buffer_sizes[(gpu.rank, Buffer.input)], gpu.input_chunks)
        gpu.output_chunks = max(buffer_sizes[(gpu.rank, Buffer.output)], gpu.output_chunks)
        gpu.scratch_chunks = max(buffer_sizes[(gpu.rank, Buffer.scratch)], gpu.scratch_chunks)

    # Since the LL protocol doubles the scratch size, we need to make sure all GPUs have the
    # same scratch size; otherwise the offset calculation will be wrong.
    if program.protocol == "LL":
        max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
        for gpu in program.gpus:
            gpu.scratch_chunks = max_scratch

    # Get channel info for each GPU and threadblock
    for gpu in program.gpus:
        gpu.threadblocks = sorted(gpu.threadblocks, key=lambda tb: tb.id)
        chan_dict = {}
        # the channel key is the tuple (srcBuffer, dstBuffer, type)
        for tb in gpu.threadblocks:
            for ch in tb.channels:
                key = (ch.srcBuffer, ch.dstBuffer, ch.type)
                if key not in chan_dict:
                    chan_dict[key] = [(tb.id, ch.connected_to)]
                else:
                    chan_dict[key].append((tb.id, ch.connected_to))
        for key, value in chan_dict.items():
            chan_dict[key] = sorted(value)
        gpu.channels = chan_dict

    # Remove dependencies of wait ops on signal ops; waits actually depend on the remote chunk.
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for op in tb.ops:
                if op.inst == Instruction.wait:
                    op.depends = list(filter(lambda dep: dep.inst != Instruction.signal, op.depends))

    # Filter out redundant dependencies:
    # e.g. if op1 and op2 depend on op, and op1 happens before op2,
    # then op2 does not need to explicitly depend on op.
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            running_depends = []
            for op in tb.ops:
                op.depends = list(filter(lambda dep: dep not in running_depends, op.depends))
                running_depends = running_depends + op.depends

    # Do some additional postprocessing of operations:
    # - Expand operations with dependencies into explicit no-op instructions
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            new_ops = []
            for op in tb.ops:
                if op.inst in _insts_no_need_sync_barrier:
                    new_ops.append(op)
                    continue
                # Expand extra dependencies into nop operations
                nop = Op(Instruction.nop, -1, None, None, [])
                for i, dep in enumerate(op.depends):
                    # barrier already syncs all threads
                    if dep.inst != Instruction.barrier:
                        nop.depends.append(dep)
                if len(new_ops) > 0 and (
                    new_ops[-1].inst == Instruction.barrier or new_ops[-1].inst == Instruction.nop
                ):
                    new_ops[-1].depends.extend(nop.depends)
                elif len(nop.depends) > 0:
                    new_ops.append(nop)
                new_ops.append(op)
            tb.ops = new_ops

    # Update step and tb id for ops
    for gpu in program.gpus:
        for tb in gpu.threadblocks:
            for i, op in enumerate(tb.ops):
                op.step = i
                op.tb = tb.id

    # Need to calculate channel info for each GPU
    nchannels = 0
    for gpu in program.gpus:
        max_tb_channels = 0
        if len(gpu.threadblocks) > 0:
            max_tb_channels = max(tb.channel + 1 for tb in gpu.threadblocks)
        nchannels = max(nchannels, max_tb_channels)
    return _dump_to_json(program)


@dataclass
class _JsonInstruction:
    name: str
    i_buff: Optional[Dict[str, str]] = None
    i_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None
    o_buff: Optional[Dict[str, str]] = None
    o_cids: Optional[List[Dict[str, Union[int, List[int]]]]] = None
    src: Optional[int] = None
    srcs: Optional[List[Dict[str, Union[int, str]]]] = None
    srcbuff: Optional[str] = None
    srcoff: Optional[int] = None
    dst: Optional[int] = None
    dsts: Optional[List[Dict[str, Union[int, str]]]] = None
    dstbuff: Optional[str] = None
    dstoff: Optional[int] = None
    ctype: Optional[str] = None
    cnt: Optional[int] = None
    deps: Optional[List[Dict[str, int]]] = None
    nthread_blocks: Optional[int] = None
    barrier_id: Optional[int] = None


class _OpConverter(ABC):
    def get_channel_ids(self, chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_type):
        channel_ids = []
        key = (src_buffer, dst_buffer, chan_type)
        if chan_type == ChannelType.nvls:
            ranks = []
            for c in chunk_list:
                ranks.append(c.rank)
            channel_ids.extend(
                [{"id": id} for id, ele in enumerate(tb_channel_dict[key]["connectedTo"]) if set(ele) == set(ranks)]
            )
        else:
            for c in chunk_list:
                channel_ids.extend(
                    [
                        {"id": id, "off": c.index}
                        for id, ele in enumerate(tb_channel_dict[key]["connectedTo"])
                        if ele == c.rank
                    ]
                )
        return channel_ids
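    # e.g. (hypothetical): with tb_channel_dict[key]["connectedTo"] == [1, 3] and a single
    # chunk on rank 3 at index 4, the non-nvls branch yields [{"id": 1, "off": 4}].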

    @abstractmethod
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        pass


class _SignalFlushConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        return _JsonInstruction(
            name=op.inst.value,
            o_buff=o_buff,
            o_cids=dst_channel_ids,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _WaitConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=i_buff,
            i_cids=src_channel_ids,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _ReadReduceCopyConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        dst = op.dst
        src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=i_buff,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            i_cids=src_channel_ids,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _ReadReduceCopySendConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        dst_channel_ids = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, op.channel_type
        )
        i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
        dst = op.dst
        src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=i_buff,
            i_cids=src_channel_ids,
            o_buff=o_buff,
            o_cids=dst_channel_ids,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _ReduceSendConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        dst_channel_ids = self.get_channel_ids(
            op.dsts, tb_channel_dict, op.dst.buffer, op.dsts[0].buffer, ChannelType.sm
        )
        o_buff = {"src": op.dst.buffer.value, "dst": op.dsts[0].buffer.value}
        srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
        dst = op.dst
        src = op.dst  # TODO(binyli): fix this
        return _JsonInstruction(
            name=op.inst.value,
            o_buff=o_buff,
            o_cids=dst_channel_ids,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            srcs=srcs,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _ReduceConverters(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
        dst = op.dst
        src = op.dst
        return _JsonInstruction(
            name=op.inst.value,
            srcs=srcs,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _NopConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        return _JsonInstruction(
            name=op.inst.value,
            deps=list(map(lambda dep: {"tb": dep.tb, "step": dep.step}, op.depends)),
        )


class _BarrierConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        return _JsonInstruction(
            name=op.inst.value,
            nthread_blocks=len(op.extra["tb_list"]),
            barrier_id=op.extra["barrier_id"],
        )


class _PutConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        o_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        srcs = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.srcs))
        return _JsonInstruction(
            name=op.inst.value,
            o_buff=o_buff,
            o_cids=dst_channel_ids,
            srcs=srcs,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _GetConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        i_buff = {"src": op.src.buffer.value, "dst": op.dst.buffer.value}
        dsts = list(map(lambda x: {"buff": x.buffer.value, "off": x.index}, op.dsts))
        return _JsonInstruction(
            name=op.inst.value,
            i_buff=i_buff,
            i_cids=src_channel_ids,
            dsts=dsts,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _CopyConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src = op.src
        dst = op.dst
        return _JsonInstruction(
            name=op.inst.value,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


class _GroupLoadReduceStoreConverter(_OpConverter):
    def to_json(self, op: Op, tb_channel_dict: dict) -> _JsonInstruction:
        src = op.src
        dst = op.dst
        src_channel_ids = self.get_channel_ids(op.srcs, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        dst_channel_ids = self.get_channel_ids(op.dsts, tb_channel_dict, op.src.buffer, op.dst.buffer, op.channel_type)
        return _JsonInstruction(
            name=op.inst.value,
            src=src.rank,
            srcbuff=src.buffer.value,
            srcoff=src.index,
            dst=dst.rank,
            dstbuff=dst.buffer.value,
            dstoff=dst.index,
            i_cids=src_channel_ids,
            o_cids=dst_channel_ids,
            ctype=op.channel_type.value,
            cnt=op.cnt(),
        )


_json_converter_map: Dict[Instruction, _OpConverter] = {
    Instruction.signal: _SignalFlushConverter(),
    Instruction.flush: _SignalFlushConverter(),
    Instruction.wait: _WaitConverter(),
    Instruction.read_reduce_copy: _ReadReduceCopyConverter(),
    Instruction.read_reduce_copy_send: _ReadReduceCopySendConverter(),
    Instruction.reduce_send: _ReduceSendConverter(),
    Instruction.reduce_send_packet: _ReduceSendConverter(),
    Instruction.reduce: _ReduceConverters(),
    Instruction.reduce_packet: _ReduceConverters(),
    Instruction.nop: _NopConverter(),
    Instruction.barrier: _BarrierConverter(),
    Instruction.put: _PutConverter(),
    Instruction.put_packet: _PutConverter(),
    Instruction.put_with_signal: _PutConverter(),
    Instruction.put_with_signal_and_flush: _PutConverter(),
    Instruction.get: _GetConverter(),
    Instruction.copy: _CopyConverter(),
    Instruction.copy_packet: _CopyConverter(),
    Instruction.transform_to_packet: _CopyConverter(),
    Instruction.group_load_reduce_store: _GroupLoadReduceStoreConverter(),
}


def _dump_to_json(program: Program):
    gpus = []

    def remove_empty_fields(d):
        return {k: v for k, v in d.items() if v not in [None, "", [], {}]}

    max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
    max_input = max(gpu.input_chunks for gpu in program.gpus)
    max_output = max(gpu.output_chunks for gpu in program.gpus)

    for id, gpu in enumerate(program.gpus):
        gpu_instance = {
            "id": id,
            "inputChunks": gpu.input_chunks,
            "outputChunks": gpu.output_chunks,
            "scratchChunks": gpu.scratch_chunks,
            "chunkGroups": program.num_chunk_groups,
            "threadblocks": [],
            "channels": [],
        }
        for (srcBuffer, dstBuffer, type), channels in gpu.channels.items():
            obj = {
                "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer,
                "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer,
                "type": type.value,
                "connectedTo": [ch[1] for ch in channels],
            }
            if type == ChannelType.nvls:
                obj["connectedTo"] = [sorted(list(peers)) for peers in obj["connectedTo"]]
            gpu_instance["channels"].append(obj)
        gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"]))
        gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"]))

        # Render NVLS channels for this GPU as rank groups
        for i, chan in enumerate(gpu_instance["channels"]):
            if chan["type"] == "nvls":
                buff = chan["srcbuff"]
                buffer_size = (
                    max_input
                    if buff == Buffer.input.value
                    else max_output if buff == Buffer.output.value else max_scratch
                )
                gpu_instance["channels"][i] = {
                    "buff": chan["srcbuff"],
                    "type": chan["type"],
                    "rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]],
                }

        for tb in gpu.threadblocks:
            if tb.id < 0:
                continue
            ops = []
            tb_channels = []
            tb_channel_dict = {}
            for (srcBuffer, dstBuffer, type), channels in gpu.channels.items():
                obj = {
                    "srcbuff": srcBuffer.value if hasattr(srcBuffer, "value") else srcBuffer,
                    "dstbuff": dstBuffer.value if hasattr(dstBuffer, "value") else dstBuffer,
                    "type": type.value,
                    "chanIds": [id for id, ele in enumerate(channels) if ele[0] == tb.id],
                    "connectedTo": [ele[1] for ele in channels if ele[0] == tb.id],
                }
                if len(obj["chanIds"]) > 0:
                    tb_channel_dict[(srcBuffer, dstBuffer, type)] = obj
                    tb_channels.append(obj)
            tb_channels = filter(lambda x: x["type"] != "none", tb_channels)
            tb_channels = sorted(tb_channels, key=lambda x: (x["srcbuff"], x["dstbuff"]))
            for op in tb.ops:
                if op.tb == -1:
                    continue
                instr = _json_converter_map[op.inst].to_json(op, tb_channel_dict)
                ops.append(remove_empty_fields(asdict(instr)))
            threadblock = {
                "id": tb.id,
                "ops": ops,
                "channels": list(
                    map(
                        lambda x: {"src": x["srcbuff"], "dst": x["dstbuff"], "ctype": x["type"], "cids": x["chanIds"]},
                        tb_channels,
                    )
                ),
            }
            gpu_instance["threadblocks"].append(threadblock)
        gpus.append(gpu_instance)
    obj = {
        "name": program.name,
        "collective": program.collective,
        "protocol": program.protocol,
        "inplace": program.inplace,
        "gpus": gpus,
        "num_threads_per_block": program.num_threads_per_block,
        "use_double_scratch_buffer": program.use_double_scratch_buffer,
        "min_message_size": program.min_message_size,
        "max_message_size": program.max_message_size,
    }
    return json.dumps(obj, indent=2)
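For orientation, the execution plan emitted by _dump_to_json has the following top-level shape (an illustrative skeleton with invented values, not output from a real run):

# {
#   "name": "<program name>",
#   "collective": "<collective name>",
#   "protocol": "Simple",
#   "inplace": false,
#   "gpus": [
#     {"id": 0, "inputChunks": 4, "outputChunks": 4, "scratchChunks": 0, "chunkGroups": 1,
#      "threadblocks": [{"id": 0, "ops": [...], "channels": [...]}], "channels": [...]}
#   ],
#   "num_threads_per_block": 1024,
#   "use_double_scratch_buffer": false,
#   "min_message_size": 0,
#   "max_message_size": 18446744073709551615
# }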
433
python/mscclpp/language/program.py
Normal file
@@ -0,0 +1,433 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import dataclass
from mscclpp.language.collectives import Collective
from mscclpp.language.buffer import *
from mscclpp.language.types import ChannelType, ChunkRef, ReplicationPolicy, Threadblock
from mscclpp.language.ir import *
from mscclpp.language.dag import DagOptimizer, DagLower, InstructionDAG
from mscclpp.language.rank import Rank

_current_program = None


def _curr():
    global _current_program
    if _current_program is None:
        raise RuntimeError("No Program in context")
    return _current_program


# For an MSCCL++ program, we assume that a channel can be identified by
# (send_buffer, recv_buffer, type, send_tb/recv_tb), which means the send_tb and recv_tb
# must match for a signal/wait pair, and likewise for put/get operations.
# If a sender wants to target a different tb on the receiver side, it has to send to the
# same tb on the receiver first and then perform a cross-tb sync. This is a limitation
# of the current implementation.
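# For example (hypothetical tb numbering): to hand data to recv tb 1 through a channel
# keyed on recv tb 0, first put into tb 0 on the receiver, then barrier across tbs 0 and 1
# on that rank (e.g. rank(r).barrier(tb_list=[0, 1])) before tb 1 consumes the chunk.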
class MSCCLPPProgram:
    def __init__(
        self,
        name: str,
        collective: Collective,
        num_ranks: int,
        instances: int,
        protocol: str = "Simple",
        instr_fusion: bool = True,
        replication_policy: ReplicationPolicy = ReplicationPolicy.duplicated,
        num_threads_per_block: int = 1024,
        use_double_scratch_buffer: bool = False,
        min_message_size: int = 0,
        max_message_size: int = 2**64 - 1,
    ):
        self.name = name
        self.collective = collective
        self.num_ranks = num_ranks
        self.instances = instances
        self.protocol = protocol
        self.instr_fusion = instr_fusion
        self.replication_policy = replication_policy
        self.num_threads_per_block = num_threads_per_block
        self.use_double_scratch_buffer = use_double_scratch_buffer
        self.min_message_size = min_message_size
        self.max_message_size = max_message_size
        assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple or LL"
        self.run_opt = True  # Runs optimization passes
        # Initialize the input buffers
        self.buffers = collective.init_buffers()
        self.instr_dag = InstructionDAG(self.num_ranks, self.buffers)
        self.ranks = []
        for r in range(self.num_ranks):
            self.ranks.append(Rank(r))
            for index, chunk in enumerate(self.buffers[r][Buffer.input]):
                buffer, index = self.collective.get_buffer_index(r, Buffer.input, index)
                ref = self.get_ref(r, buffer, index, 1)
                # self.chunk_dag.init_chunk(chunk, ref)
                self.instr_dag.add_start(r, buffer, index, ref)

    def __enter__(self):
        global _current_program
        if _current_program is not None:
            raise RuntimeError("There is already a MSCCLPP Program in context")
        _current_program = self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global _current_program
        if _current_program != self:
            raise RuntimeError("This program is not currently in context")
        _current_program = None

    def _convert_to_execution_plan(self):
        ops = self.instr_dag.convert_set_list()
        ops = sorted(ops, key=lambda x: x.step)
        for op in ops:
            rank = op.rank
            tbid = op.tb
            if tbid not in self.instr_dag.tbs[rank]:
                self.instr_dag.tbs[rank][tbid] = Threadblock(id=tbid)
            tb = self.instr_dag.tbs[rank][tbid]
            tb.ops.append(op)

    def get_rank_ref(self, rank):
        return RankRef(rank, self)

    # Tracks a send operation on the buffers
    def apply_send(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            db[dst_index + i] = sb[src_index + i]

    # Tracks a reduce operation on the buffers
    def apply_reduce(self, src, src_buffer, src_index, dst, dst_buffer, dst_index, size):
        src_buffer, src_index = self.collective.get_buffer_index(src, src_buffer, src_index)
        dst_buffer, dst_index = self.collective.get_buffer_index(dst, dst_buffer, dst_index)
        sb = self.buffers[src][src_buffer]
        db = self.buffers[dst][dst_buffer]
        for i in range(size):
            reduce_chunk = db[dst_index + i]
            sent_chunk = sb[src_index + i]
            db[dst_index + i] = reduce_chunk.reduce(dst, sent_chunk)

    def get_ref(self, rank, buffer, index, size):
        buffer, index = self.collective.get_buffer_index(rank, buffer, index)
        return Ref(rank, buffer, index, size, self)

    def get_chunks(self, rank, buffer, index, size=1):
        chunks = [None] * size
        for i in range(0, size):
            if self.buffers[rank][buffer] and index + i < len(self.buffers[rank][buffer]):
                chunks[i] = self.buffers[rank][buffer][index + i]
            else:
                chunks[i] = None
        return chunks

    def check_buffer_exists(self, rank, name):
        if name not in self.buffers[rank]:
            self.buffers[rank][name] = BufferSlice(Buffer.scratch, name)

    # Checks that all chunks that should be on each rank
    # are present in the output buffer.
    def check(self):
        return self.collective.check(self)

    # Lower program to MSCCLPP
    def lower(self):
        self._convert_to_execution_plan()
        self.instr_dag.complete_channels()
        dag_optimizer = DagOptimizer(self.instr_dag)
        dag_optimizer.remove_redundant_signal_wait()
        if self.instr_fusion:
            dag_optimizer.fuse_instructions()
        dag_lower = DagLower(self.instr_dag)
        gpu_prgms = dag_lower.lower(self.instances, self.replication_policy)
        program = Program(
            self.name,
            self.collective.name,
            self.collective.inplace,
            self.protocol,
            gpu_prgms,
            self.collective.num_chunk_groups * self.instances,
            self.num_threads_per_block,
            self.use_double_scratch_buffer,
            self.min_message_size,
            self.max_message_size,
        )
        for gpu in program.gpus:
            gpu.input_chunks = len(self.buffers[gpu.rank][Buffer.input]) * self.instances
            if not self.collective.inplace:
                gpu.output_chunks = len(self.buffers[gpu.rank][Buffer.output]) * self.instances
        return program

    def generate_json(self):
        return ir_to_json(self.lower())


def Json():
    print(_curr().generate_json())


@dataclass
class RankRef:
    rank: int
    prog: MSCCLPPProgram

    def _get_barrier_id(self, tb_list) -> int:
        return self.prog.ranks[self.rank].get_barrier_id(tb_list)

    def barrier(self, tb_list):
        barrier_id = self._get_barrier_id(tb_list)
        return self.prog.instr_dag.add_barrier(self.rank, tb_list, barrier_id)


@dataclass
class Ref(ChunkRef):
    prog: MSCCLPPProgram

    def __repr__(self):
        return f"Ref(Buffer:{self.buffer}, Index:{self.index}, Size:{self.size}, Rank:{self.rank})"

    def _end(self):
        return self.index + self.size

    def _get_chunk(self, index):
        return self.prog.buffers[self.rank][self.buffer][index]

    def split(self, num):
        assert self.size % num == 0, f"Trying to split a chunk of {self.size} elements into {num} parts"
        chunks = [None] * num
        size = self.size // num
        for i in range(num):
            index = self.index + i * size
            chunks[i] = self.prog.get_ref(self.rank, self.buffer, index, size)
        return chunks

    def group(self, other):
        assert self.rank == other.rank, f"Trying to concatenate chunks on ranks {self.rank} and {other.rank}"
        assert self.buffer == other.buffer, f"Trying to concatenate chunks in {self.buffer} and {other.buffer}"
        if self.index < other.index:
            first = self
            second = other
        else:
            first = other
            second = self

        end = max(first._end(), second._end())
        return Ref(self.rank, self.buffer, first.index, end - first.index, self.prog)
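    # e.g. (hypothetical): grouping a 1-chunk Ref at index 2 with a 2-chunk Ref at
    # index 0 of the same buffer yields one Ref(index=0, size=3).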

    def _get_buffer_index(self, remote_rank, buffer, index):
        if index == -1 and buffer is None:
            return self.buffer, self.index
        elif index == -1 and buffer is not Buffer.input and buffer is not Buffer.output:
            return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
        return buffer, index

    def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
        self.prog.check_buffer_exists(dst, buffer)
        assert self.rank != dst, "Cannot put to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
        if use_packet:
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, True)
            self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
            self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
        else:
            self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)
        return dst_chunkref

    def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
        return self._put(dst, buffer, index, sendtb, chan_type)

    def put_packet(
        self,
        dst,
        buffer=None,
        index=-1,
        sendtb=-1,
        chan_type=ChannelType.sm,
        temp_buffer=None,
        temp_buffer_index=-1,
    ):
        chunk_ref = self
        if chan_type == ChannelType.proxy:
            assert temp_buffer is not None, "Need to specify a temporary buffer for proxy channels"
            chunk_ref = self._copy(
                self.rank, temp_buffer, temp_buffer_index, sendtb, trans_from_packet=False, trans_to_packet=True
            )
        return chunk_ref._put(dst, buffer, index, sendtb, chan_type, True)

    def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
        self.prog.check_buffer_exists(src, buffer)
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot get from the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)

        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)

        self.prog.apply_send(src, buffer, index, self.rank, self.buffer, self.index, self.size)
        self.prog.instr_dag.add_get(receiver, src_chunkref, self, recvtb, chan_type)

    # For signal and wait, we currently assume the pair uses the same tb index. In the future
    # we need to infer the tb index from the instruction DAG. A channel is defined as
    # (send_tb, src_buffer, recv_tb, dst_buffer, type); then we can use DAG info to reduce
    # the number of channels.
    def signal(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot signal to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_signal(sender, self, dst_chunkref, sendtb, chan_type)

    # Only proxy channels need to use this function.
    def flush(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.proxy):
        assert chan_type == ChannelType.proxy, "Only proxy channels can use flush"
        sender = self.rank
        receiver = dst
        assert sender != receiver, "Cannot flush to the same rank"
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        self.prog.instr_dag.add_flush(sender, self, dst_chunkref, sendtb)

    def wait(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
        sender = src
        receiver = self.rank
        assert sender != receiver, "Cannot wait on the same rank"
        buffer, index = self._get_buffer_index(src, buffer, index)

        src_chunkref = self.prog.get_ref(src, buffer, index, self.size)
        self.prog.instr_dag.add_wait(receiver, self, src_chunkref, recvtb, chan_type)

    def _copy(self, dst, buffer=None, index=-1, sendtb=-1, trans_from_packet=False, trans_to_packet=False):
        self.prog.check_buffer_exists(dst, buffer)
        buffer, index = self._get_buffer_index(dst, buffer, index)

        dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
        # Check if we are copying the chunk to the same index (easy mistake when using inplace)
        if dst_chunkref == self:
            return
        self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)

        assert self.rank == dst, "Chunk copy only supports intra-rank communication"
        self.prog.instr_dag.add_copy(self.rank, self, dst_chunkref, sendtb, trans_from_packet, trans_to_packet)

        return dst_chunkref

    # Copies the chunk(s) referenced by this chunkref onto Rank dst at location (buffer, index)
    def copy(self, dst, buffer=None, index=-1, sendtb=-1):
        return self._copy(dst, buffer, index, sendtb)

    def copy_packet(self, dst, buffer=None, index=-1, sendtb=-1):
        return self._copy(dst, buffer, index, sendtb, trans_from_packet=True, trans_to_packet=False)

    def _reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm, use_packet=False):
        dst = self.rank
        src = other_chunkref.rank

        self.prog.apply_reduce(
            src, other_chunkref.buffer, other_chunkref.index, dst, self.buffer, self.index, self.size
        )
        if use_packet:
            assert src == dst, "Packet reduce only supports intra-rank communication"

        if src != dst:
            self.prog.instr_dag.add_read_reduce(dst, other_chunkref, self, recvtb, channel_type)
        else:
            self.prog.instr_dag.add_reduce(src, other_chunkref, self, recvtb, use_packet)

        return self

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce(self, other_chunkref, recvtb=-1, channel_type=ChannelType.sm):
        return self._reduce(other_chunkref, recvtb, channel_type)

    # Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref
    def reduce_packet(self, other_chunkref, recvtb=-1):
        return self._reduce(other_chunkref, recvtb, use_packet=True)

    # """
    # Group operations. These operations are used to perform collective operations across multiple chunks.
    # For now, all chunks must have the same buffer type and offset.
    # """
    # Reads the chunk(s) referenced by other_chunkrefs and reduces them into the chunk referenced by this chunkref
    def group_load_reduce(self, other_chunkrefs: list, recvtb=-1, chan_type=ChannelType.nvls):
        assert (
            len(other_chunkrefs) > 0 and chan_type == ChannelType.nvls
        ), "Group load reduce only supports nvls channel"
        nranks_per_node = self.prog.collective.num_ranks_per_node
        for other_chunkref in other_chunkrefs:
            assert (
                self.rank // nranks_per_node == other_chunkref.rank // nranks_per_node
            ), "Group load reduce only supports chunks on the same node"
            assert self.buffer == other_chunkref.buffer, "Group load reduce only supports chunks with the same buffer"
            assert self.index == other_chunkref.index, "Group load reduce only supports chunks with the same index"

            src_chunkref = other_chunkref
            self.prog.apply_reduce(
                src_chunkref.rank,
                src_chunkref.buffer,
                src_chunkref.index,
                self.rank,
                self.buffer,
                self.index,
                self.size,
            )
        self.prog.instr_dag.add_group_load_reduce(self.rank, other_chunkrefs, self, recvtb, chan_type)
        return self

    # Copies the chunk(s) referenced by this chunkref onto other_chunkrefs
    def group_store(self, dsts: list, index=-1, buffer=None, sendtb=-1, chan_type=ChannelType.nvls):
        for dst in dsts:
            self.prog.check_buffer_exists(dst, buffer)
        assert index == -1 or self.index == index, "Group store only supports chunks with the same index"
        assert chan_type == ChannelType.nvls, "Group store only supports nvls channel"

        other_chunkrefs = []
        nrank_per_node = self.prog.collective.num_ranks_per_node
        for dst in dsts:
            # Direct linked
            buffer, index = self._get_buffer_index(dst, buffer, index)

            assert self.buffer == buffer, "Group store only supports chunks with the same buffer"
            assert (
                self.rank // nrank_per_node == dst // nrank_per_node
            ), "Group store only supports chunks on the same node"

            dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
            self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
            other_chunkrefs.append(dst_chunkref)
        # add new op here
        self.prog.instr_dag.add_group_store(self.rank, self, other_chunkrefs, sendtb, chan_type)

    def get_origin_index(self, index=0):
        return self._get_chunk(index + self.index).origin_index

    def get_origin_rank(self, index=0):
        return self._get_chunk(index + self.index).origin_rank

    def get_dst_index(self, index=0):
        return self._get_chunk(index + self.index).dst_index

    def get_dst_rank(self, index=0):
        return self._get_chunk(index + self.index).dst_rank

    def print_chunk_info(self, index=0):
        print(self._get_chunk(index + self.index))


def chunk(rank, buffer, index, size=1) -> Ref:
    if _curr().buffers[rank][buffer][index] is None:
        return None
    return _curr().get_ref(rank, buffer, index, size)


def rank(rank) -> RankRef:
    return _curr().get_rank_ref(rank)


def Check():
    return _curr().check()
33
python/mscclpp/language/rank.py
Normal file
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import dataclass, field
from typing import Dict


class BarrierInfo:
    def __init__(self, tb_list):
        self.tb_list = tb_list

    def __eq__(self, other):
        return self.tb_list == other.tb_list

    def __hash__(self):
        return hash(tuple(self.tb_list))


@dataclass
class Rank:
    rank_id: int
    current_max_barrier_id: int = 0
    current_barriers: Dict[BarrierInfo, int] = field(default_factory=dict)

    def get_barrier_id(self, tb_list):
        barrier_info = BarrierInfo(tb_list)
        if barrier_info in self.current_barriers:
            return self.current_barriers[barrier_info]
        else:
            self.current_barriers[barrier_info] = self.current_max_barrier_id
            barrier_id = self.current_max_barrier_id
            self.current_max_barrier_id += 1
            return barrier_id
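get_barrier_id memoizes barrier ids per tb_list, so repeated barriers over the same set of threadblocks share one id. A small usage sketch:

# Sketch: barrier ids are de-duplicated per tb_list on each Rank.
r = Rank(rank_id=0)
assert r.get_barrier_id([0, 1, 2]) == 0  # first barrier over tbs 0-2 gets id 0
assert r.get_barrier_id([0, 1, 2]) == 0  # same tb_list -> cached id
assert r.get_barrier_id([0, 1]) == 1     # new tb_list -> next id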
173
python/mscclpp/language/types.py
Normal file
@@ -0,0 +1,173 @@
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Union, List
|
||||
|
||||
from mscclpp.language.buffer import Buffer
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gpu:
|
||||
rank: int
|
||||
threadblocks: list = field(default_factory=list)
|
||||
|
||||
# From ncclize
|
||||
precopies: list = field(default_factory=list)
|
||||
postcopies: list = field(default_factory=list)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
input_chunks: int = 0
|
||||
output_chunks: int = 0
|
||||
scratch_chunks: int = 0
|
||||
scratch: dict = field(default_factory=dict)
|
||||
channels: dict = field(default_factory=dict)
|
||||
|
||||
def scratch_size(self):
|
||||
return max((idx for addr, idx in self.scratch.items()), default=-1) + 1
|
||||
|
||||
|
||||
@dataclass
class Program:
    name: str
    collective: str
    inplace: bool
    protocol: str
    gpus: List[Gpu] = field(default_factory=list)
    num_chunk_groups: int = 1
    num_threads_per_block: int = 1024
    use_double_scratch_buffer: bool = False
    min_message_size: int = 0
    max_message_size: int = 2**64 - 1


@dataclass
class Threadblock:
    channel: int = -1
    send: int = -1
    recv: int = -1
    ops: list = field(default_factory=list)
    rbid: int = -1  # threadblock id of the receiver
    id: int = -1
    channels: list = field(default_factory=list)

    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return id(self)
class ReplicationPolicy(Enum):
    # Each instance deals with a different chunk; replicas are grouped by instance:
    # Chunk A, Chunk B -> Chunk A0, Chunk B0, Chunk A1, Chunk B1
    duplicated = "duplicated"
    # Each instance deals with a different chunk, with replicas interleaved per chunk:
    # Chunk A, Chunk B -> Chunk A0, Chunk A1, Chunk B0, Chunk B1
    interleaved = "interleaved"
    # Packs multiple instances to deal with the same chunk and share the channels
    packed = "packed"

    def __str__(self):
        return self.value
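The duplicated and interleaved layouts differ only in how a (chunk, instance) pair maps to a replicated chunk index. A hypothetical sketch of the mapping implied by the comments above (the formulas are assumptions, not code from this commit):

# Assumed index mapping implied by the layout comments above; not commit code.
def replicated_index(chunk_idx, instance, num_chunks, num_instances, policy):
    if policy == "duplicated":
        # A0, B0, A1, B1: all chunks of instance 0, then all of instance 1, ...
        return instance * num_chunks + chunk_idx
    if policy == "interleaved":
        # A0, A1, B0, B1: each chunk's replicas are adjacent
        return chunk_idx * num_instances + instance
    raise ValueError(f"unknown policy: {policy}")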
class Instruction(Enum):
    start = "start"
    nop = "nop"
    read_reduce_copy = "rrc"
    read_reduce_copy_send = "rrcs"
    reduce_send = "rs"
    copy = "copy"
    reduce = "reduce"
    copy_packet = "cpkt"
    transform_to_packet = "tpkt"
    reduce_send_packet = "rspkt"
    reduce_packet = "rpkt"
    put = "put"
    put_packet = "ppkt"
    put_with_signal = "pws"
    put_with_signal_and_flush = "pwsf"
    get = "get"
    wait = "wait"
    signal = "signal"
    flush = "flush"
    barrier = "barrier"
    group_store = "gstore"
    group_load_reduce = "glre"
    group_load_reduce_store = "glres"

    def __str__(self):
        return self.value
@dataclass
class ChunkRef:
    rank: int
    buffer: Buffer
    index: int
    size: int

    def __hash__(self):
        return hash((self.rank, self.buffer, self.index, self.size))


class ChannelType(Enum):
    proxy = "proxy"
    sm = "sm"
    none = "none"
    nvls = "nvls"

    def __str__(self):
        return self.value


@dataclass(frozen=True)
class Channel:
    srcBuffer: Buffer
    dstBuffer: Buffer
    type: ChannelType
    connected_to: Union[int, List[int]]

    def __hash__(self):
        # Ensure connected_to is converted to a tuple if it's a list
        connected_to_hashable = tuple(self.connected_to) if isinstance(self.connected_to, list) else self.connected_to
        return hash((self.srcBuffer, self.dstBuffer, self.type, connected_to_hashable))
@dataclass
class Op:
    inst: Instruction
    rank: int
    src: ChunkRef
    dst: ChunkRef
    depends: list = field(default_factory=list)
    step: int = -1  # Step in the TB
    tb: int = -1  # TB this op is assigned to
    prev: list = field(default_factory=list)  # List of instructions that happen before
    next: list = field(default_factory=list)  # List of instructions that happen after
    channel: int = -1
    channel_type: ChannelType = ChannelType.none
    srcs: list = field(default_factory=list)
    dsts: list = field(default_factory=list)
    extra: dict = field(default_factory=dict)

    def cnt(self):
        if self.src:
            if self.dst:
                assert self.src.size == self.dst.size
            return self.src.size
        elif self.dst:
            return self.dst.size
        else:
            return 0

    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return id(self)

    def __repr__(self):
        return f"Op({self.inst}, {self.rank}, {self.src}, {self.dst}, step:{self.step}, tb:{self.tb})"
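Op.cnt() above returns the op's chunk count, asserting that src and dst sizes agree when both are present; a small sketch with hypothetical chunk references:

# Hypothetical sketch of Op.cnt(); the ChunkRef values are made up.
op = Op(Instruction.copy, rank=0, src=ChunkRef(0, Buffer.input, 0, 2), dst=ChunkRef(0, Buffer.output, 0, 2))
assert op.cnt() == 2  # src and dst sizes match
assert Op(Instruction.nop, rank=0, src=None, dst=None).cnt() == 0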
94
python/mscclpp/language/utils.py
Normal file
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from mscclpp.language.types import Op


def remove_op(op: Op):
    # Splice op out of the DAG, connecting each predecessor directly to each successor.
    for p in op.prev:
        p.next.remove(op)
        p.next += op.next
        p.next = list(set(p.next))

    for n in op.next:
        n.prev.remove(op)
        n.prev = list(set(op.prev) | set(n.prev))

    op.next = []
    op.prev = []
def merge_op(op: Op, other_op: Op):
    # Fuse other_op into op, rewiring all of other_op's edges to op.
    # Op declares prev/next as lists, so list operations are used throughout.
    if other_op in op.next:
        op.next.remove(other_op)
        other_op.prev.remove(op)
    for p in other_op.prev:
        p.next.remove(other_op)
        p.next.append(op)

    for n in other_op.next:
        n.prev.remove(other_op)
        n.prev.append(op)

    op.prev = list(set(op.prev) | set(other_op.prev))
    op.next = list(set(op.next + other_op.next))
def circular_dep_after_merge(op: Op, other_op: Op):
    # BFS over the successors of both ops: if either op is reachable again,
    # merging them would create a cycle.
    root = set([op, other_op])
    frontier = set(op.next)
    if other_op in frontier:
        frontier.remove(other_op)
    frontier = list(frontier.union(other_op.next))
    while len(frontier) > 0:
        current = frontier[0]
        for n in current.next:
            # The root node will be visited again if there is a circular dependency
            if n in root:
                return True
            frontier.append(n)
        frontier = frontier[1:]
    return False
def all_prevs_visited_after_merge(op: Op, other_op: Op):
    """
    Consider op2 with op2.prev = [op1, op3], op1.next = [op2], and op3.next = [op2],
    where op1 and op2 satisfy the merge conditions. We only apply the merge once every
    predecessor of op2 has been visited, i.e. op1 is the last predecessor of op2.
    """
    step = op.step
    for prev in other_op.prev:
        if prev.step > step:
            return False
    return True
def same_tb(op1: Op, op2: Op):
    return op1.tb == op2.tb and op1.channel == op2.channel


def same_count(op1: Op, op2: Op):
    return op1.cnt() == op2.cnt()


def same_buf_dst(op1: Op, op2: Op):
    return op1.dst.buffer == op2.dst.buffer and op1.dst.index == op2.dst.index


def same_src_dst_buffer_type(op1: Op, op2: Op):
    return op1.src.buffer == op2.src.buffer and op1.dst.buffer == op2.dst.buffer


def buf_dst_src_match(op1: Op, op2: Op):
    return op1.dst.buffer == op2.src.buffer and op1.dst.index == op2.src.index


def same_buf_src(op1: Op, op2: Op):
    return op1.src.buffer == op2.src.buffer and op1.src.index == op2.src.index


def same_chan_type(op1: Op, op2: Op):
    return op1.channel_type == op2.channel_type
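A minimal sketch (hypothetical, not part of the commit) of how these DAG helpers behave on a three-op chain built from placeholder Ops:

# Hypothetical sketch of remove_op / circular_dep_after_merge on a tiny DAG.
from mscclpp.language.types import Instruction, Op
from mscclpp.language.utils import circular_dep_after_merge, remove_op

a = Op(Instruction.put, rank=0, src=None, dst=None)
b = Op(Instruction.nop, rank=0, src=None, dst=None)
c = Op(Instruction.wait, rank=0, src=None, dst=None)
a.next, b.prev = [b], [a]  # chain: a -> b -> c
b.next, c.prev = [c], [b]

remove_op(b)  # splice out the nop, leaving a -> c
assert a.next == [c] and c.prev == [a]
assert not circular_dep_after_merge(a, c)  # no cycle in a -> c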
34
python/test/configs/mscclpp_lang_test_config.json
Normal file
@@ -0,0 +1,34 @@
[
    {
        "filename": "allgather_barrier.py",
        "args": ["8", "8"]
    },
    {
        "filename": "allreduce_allpairs_packet.py",
        "args": ["8", "8"]
    },
    {
        "filename": "allreduce_allpairs_get.py",
        "args": ["8", "8"]
    },
    {
        "filename": "allreduce_allpairs.py",
        "args": ["8", "8"]
    },
    {
        "filename": "allreduce_ring.py",
        "args": ["8", "8"]
    },
    {
        "filename": "send_recv_packet.py",
        "args": ["2"]
    },
    {
        "filename": "send_recv_proxy.py",
        "args": ["2"]
    },
    {
        "filename": "allreduce_nvls.py",
        "args": ["8", "2"]
    }
]
47
python/test/test_generate_mscclpp_lang_result.py
Normal file
@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import json
from pathlib import Path
import subprocess


def run_examples(input_folder, configs, output_folder):
    for config in configs:
        file_name = config["filename"]
        args = config["args"]

        input_file_path = Path(input_folder) / file_name
        # Strip the ".py" from the filename and add ".output"
        base_file_name = file_name[:-3] if file_name.endswith(".py") else file_name
        base_file_name = base_file_name.replace("/", "_")
        output_file_path = Path(output_folder) / f"{base_file_name}.output"

        # Construct the command to run the Python script
        command = ["python3", str(input_file_path)] + args

        # Run the command and capture output
        with open(output_file_path, "w") as output_file:
            result = subprocess.run(command, stdout=output_file, stderr=subprocess.STDOUT, text=True)

        # Optional: Check the return code to handle errors
        if result.returncode != 0:
            print(f"Error running {file_name}. See {output_file_path} for details.")


def main(input_folder, config_path, output_folder):
    with open(config_path, "r") as f:
        config = json.load(f)

    Path(output_folder).mkdir(parents=True, exist_ok=True)
    run_examples(input_folder, config, output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process files according to a configuration and save the results.")
    parser.add_argument("input_folder", type=str, help="Path to the folder containing the input files.")
    parser.add_argument("config", type=str, help="Path to the configuration file.")
    parser.add_argument("output_folder", type=str, help="Path to the folder where the processed files will be saved.")
    args = parser.parse_args()
    main(args.input_folder, args.config, args.output_folder)
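Invocation sketch, assuming the repository root as the working directory (the output folder name is arbitrary): each example listed in the config runs with its args, and its stdout/stderr are captured to <name>.output in the output folder.

python3 python/test/test_generate_mscclpp_lang_result.py python/examples \
    python/test/configs/mscclpp_lang_test_config.json ./lang_test_output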