Files
mscclpp/python/examples/allreduce_allpairs_get.py
Binyang Li af0bb86e07 Merge mscclpp-lang to mscclpp project (#442)
First step to merge msccl-tools into the mscclpp repo. In this step we will
move all msccl-related code, pass the current tests, and do some
necessary refactoring.

Add `mscclpp.language` module
Add `_InstructionOptimizer` and `DagOptimizer` class to optimize the dag
Add `DagLower` to lower dag to intermediate representation 
Add documents for mscclpp.language
Remove msccl related code
2025-01-22 09:47:37 -08:00

79 lines
2.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_allpairs(gpus: int, instances: int) -> None:
    """
    AllReduce with all pairs algorithm using get semantics.

    Emits the program as JSON (via ``Json()``) and verifies it (via
    ``Check()``) as a side effect; nothing is returned.

    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Signal all ranks to notify that the data is ready.
    4. Wait for all chunks to be ready, then retrieve the chunks from all peers.

    Args:
        gpus: Number of GPUs (ranks) participating in the collective.
        instances: Number of parallel instances of the program to generate.
    """
    size = gpus
    # Each rank owns `size` chunks, so the full buffer is size*size chunks.
    chunksperloop = gpus * gpus
    collective = AllReduce(size, chunksperloop, True)
    with MSCCLPPProgram(
        "allreduce_pairs",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Phase 1: each rank reduces its own row of chunks in place.
        # Thread block `tb` on each rank handles chunk offset `tb` within the row.
        for rank in range(size):
            for tb in range(size):
                # `index` is the start of this rank's row in the input buffer.
                index = rank * size
                c = chunk(rank, Buffer.input, index + tb)
                # Make sure the data is ready: signal every peer that the
                # chunk it will read from us is available.
                for nghr in range(size):
                    peer_index = nghr * size
                    if rank != nghr:
                        c_peer = chunk(rank, Buffer.input, peer_index + tb)
                        c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb)
                # Wait for every peer's matching signal before reading.
                for nghr in range(size):
                    if rank != nghr:
                        c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
                # Reduce the chunks: pull the peer copies and accumulate locally.
                # Offsetting by `rank` staggers peer order to spread traffic.
                for i in range(size):
                    nghr = (rank + i) % size
                    if rank != nghr:
                        c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb)
                # Announce that this fully-reduced chunk is ready for peers to get.
                for nghr in range(size):
                    if rank != nghr:
                        c.signal(nghr, Buffer.input, index + tb, sendtb=tb)
        # Phase 2: wait until all the chunks are ready, then get the reduced
        # rows from every peer (allgather via get semantics).
        for rank in range(size):
            for tb in range(size):
                for nghr in range(size):
                    if rank != nghr:
                        index = nghr * size
                        c = chunk(rank, Buffer.input, index + tb)
                        c.wait(nghr, Buffer.input, index + tb, recvtb=tb)
                for i in range(size):
                    nghr = (rank + i) % size
                    index = nghr * size
                    if rank != nghr:
                        c = chunk(rank, Buffer.input, index + tb)
                        c.get(nghr, Buffer.input, index + tb, recvtb=tb)
        # Emit the generated program and validate it.
        Json()
        Check()
if __name__ == "__main__":
    # Guard the CLI entry point so importing this example module does not
    # parse sys.argv or emit a program as a side effect.
    parser = argparse.ArgumentParser()
    parser.add_argument("num_gpus", type=int, help="number of gpus")
    parser.add_argument("instances", type=int, help="number of instances")
    args = parser.parse_args()
    allreduce_allpairs(args.num_gpus, args.instances)