mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
from mscclpp.language import *
|
|
from mscclpp.language.collectives import AllReduce
|
|
from mscclpp.language.buffer import Buffer
|
|
|
|
|
|
def allreduce_allpairs(gpus, instances):
|
|
"""
|
|
AllReduce with all pairs algorithm using packets format.
|
|
Steps:
|
|
1. Each rank sends its nth chunk to the nth rank's scratch space.
|
|
2. Each rank performs a local reduction on its nth chunk using data from all other ranks' scratch spaces.
|
|
3. Each rank sends the reduced data to all other ranks' scratch spaces.
|
|
4. Each rank retrieves the final reduced result from the scratch space.
|
|
"""
|
|
size = gpus
|
|
chunksperloop = gpus * gpus
|
|
collective = AllReduce(size, chunksperloop, True)
|
|
with MSCCLPPProgram(
|
|
"allreduce_packets",
|
|
collective,
|
|
size,
|
|
instances,
|
|
protocol="LL",
|
|
use_double_scratch_buffer=True,
|
|
):
|
|
# Each rank sends the nth chunk to the nth rank into scratch space
|
|
for r1 in range(size):
|
|
for tb in range(size):
|
|
if tb == r1:
|
|
continue
|
|
remote_rank = tb
|
|
index = remote_rank * size
|
|
c = chunk(r1, Buffer.input, index, size)
|
|
c.put_packet(remote_rank, Buffer.scratch, index=r1 * size, sendtb=tb)
|
|
|
|
# Each rank performs a local reduction on the nth chunk
|
|
# Utilize 8 threadblocks for this reduction for better parallelism
|
|
for r in range(size):
|
|
for index in range(size):
|
|
c = chunk(r, Buffer.input, r * size + index)
|
|
for peer in range(size):
|
|
if peer != r:
|
|
c.reduce_packet(chunk(r, Buffer.scratch, peer * size + index), recvtb=index)
|
|
for peer in range(size):
|
|
if peer != r:
|
|
c.put_packet(peer, Buffer.scratch, (size * size) + r * size + index, sendtb=index)
|
|
|
|
# Each rank get final result from scratch space
|
|
for r in range(size):
|
|
for peer in range(size):
|
|
if peer != r:
|
|
c = chunk(r, Buffer.scratch, size * size + peer * size, size)
|
|
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
|
|
|
|
Json()
|
|
Check()
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("num_gpus", type=int, help="number of gpus")
|
|
parser.add_argument("instances", type=int, help="number of instances")
|
|
|
|
args = parser.parse_args()
|
|
|
|
allreduce_allpairs(args.num_gpus, args.instances)
|