Files
mscclpp/python/examples/allreduce_ring.py
Binyang Li af0bb86e07 Merge mscclpp-lang to mscclpp project (#442)
First step to merge msccl-tools into the mscclpp repo. In this step, we
move all msccl-related code, pass the current tests, and do some
necessary refactoring.

Add `mscclpp.language` module
Add `_InstructionOptimizer` and `DagOptimizer` class to optimize the dag
Add `DagLower` to lower dag to intermediate representation 
Add documents for mscclpp.language
Remove msccl related code
2025-01-22 09:47:37 -08:00

60 lines
2.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_ring(size, instances):
    """Build a ring-based allreduce program.

    Steps:
    1. Send a signal to the next rank and wait for the signal from the
       previous rank, so the data on the previous rank is known to be ready.
    2. Reduce the data and send it on to the next rank.
    3. After all the data is reduced, propagate it to all the ranks.

    Args:
        size: Number of ranks (GPUs) in the ring. Each rank's input buffer is
            addressed as ``size`` chunks, indexed ``0 .. size - 1``.
        instances: Number of instances, forwarded unchanged to
            ``MSCCLPPProgram``.
    """
    # NOTE(review): the positional arguments presumably mean
    # (num_ranks, chunk_factor, inplace) -- confirm against AllReduce.
    collective = AllReduce(size, size, True)
    with MSCCLPPProgram(
        "allreduce_ring",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Reduce ring: size - 1 steps. At each step, rank (index + step)
        # signals chunk `index` to the next rank, then waits on the previous
        # rank and reduces the previous rank's copy of chunk `index - 1`.
        for step in range(0, size - 1):
            for index in range(0, size):
                # Chunk one position behind `index`, wrapped into [0, size).
                prev_index = (index + size - 1) % size
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, index)
                c.signal(next_rank, Buffer.input, index, 0)
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)
                c.reduce(chunk(prev_rank, Buffer.input, prev_index), recvtb=0)

        # Propagate ring: size - 1 steps, starting at step -1 so the first
        # sender of chunk `index` is rank (index - 1), the rank on which the
        # reduce loop above last reduced that chunk.
        for step in range(-1, size - 2):
            for index in range(0, size):
                prev_index = (index + size - 1) % size
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                c = chunk(rank, Buffer.input, index)
                c.put(next_rank, Buffer.input, index, sendtb=0)
                c.signal(next_rank, Buffer.input, index, 0)
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)

        # NOTE(review): presumably emits the recorded program as JSON and
        # validates it; both must run inside the program context.
        Json()
        Check()
parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_ring(args.num_gpus, args.instances)