mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
First step to merge msccl-tools into the mscclpp repo. In this step we move all msccl-related code, pass the current tests, and do some necessary refactoring. Add the `mscclpp.language` module. Add `_InstructionOptimizer` and `DagOptimizer` classes to optimize the DAG. Add `DagLower` to lower the DAG to an intermediate representation. Add documentation for `mscclpp.language`. Remove msccl-related code.
79 lines
2.9 KiB
Python
79 lines
2.9 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
from mscclpp.language import *
|
|
from mscclpp.language.collectives import AllReduce
|
|
from mscclpp.language.buffer import Buffer
|
|
|
|
|
|
def allreduce_allpairs(gpus, instances):
    """
    AllReduce with the all-pairs algorithm using get semantics.

    Steps:
    1. Sync all ranks to ensure the data is ready.
    2. Each rank reads chunks from all peers and reduces the data.
    3. Signal all ranks to notify that the data is ready.
    4. Wait for all chunks to be ready, then retrieve the chunks from all peers.
    """
    size = gpus
    chunks_per_loop = gpus * gpus
    collective = AllReduce(size, chunks_per_loop, True)
    with MSCCLPPProgram(
        "allreduce_pairs",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Phase 1: each rank reduces peer contributions into its own slice.
        for rank in range(size):
            base = rank * size
            for tb in range(size):
                own = chunk(rank, Buffer.input, base + tb)
                # Signal every peer that this rank's source data is ready.
                for peer in range(size):
                    if peer == rank:
                        continue
                    peer_off = peer * size + tb
                    chunk(rank, Buffer.input, peer_off).signal(
                        peer, Buffer.input, peer_off, sendtb=tb
                    )
                # Block until every peer has published its data.
                for peer in range(size):
                    if peer == rank:
                        continue
                    own.wait(peer, Buffer.input, base + tb, recvtb=tb)
                # Reduce the matching chunk from each peer into the local one,
                # rotating the peer order by rank to spread the load.
                for step in range(size):
                    peer = (rank + step) % size
                    if peer == rank:
                        continue
                    own.reduce(chunk(peer, Buffer.input, base + tb), recvtb=tb)
                # Announce to every peer that the reduced result is available.
                for peer in range(size):
                    if peer == rank:
                        continue
                    own.signal(peer, Buffer.input, base + tb, sendtb=tb)

        # Phase 2: wait until all reduced chunks are ready, then fetch each
        # peer's reduced chunk via get.
        for rank in range(size):
            for tb in range(size):
                for peer in range(size):
                    if peer == rank:
                        continue
                    off = peer * size + tb
                    chunk(rank, Buffer.input, off).wait(
                        peer, Buffer.input, off, recvtb=tb
                    )
                for step in range(size):
                    peer = (rank + step) % size
                    if peer == rank:
                        continue
                    off = peer * size + tb
                    chunk(rank, Buffer.input, off).get(
                        peer, Buffer.input, off, recvtb=tb
                    )

        # Emit the program as JSON and run the DSL's consistency check.
        Json()
        Check()
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("num_gpus", type=int, help="number of gpus")
|
|
parser.add_argument("instances", type=int, help="number of instances")
|
|
|
|
args = parser.parse_args()
|
|
|
|
allreduce_allpairs(args.num_gpus, args.instances)
|