Files
mscclpp/python/examples/allreduce_allpairs_packet.py
Caio Rocha 8a564977e5 Updating MSCCLLang Examples (#462)
Co-authored-by: Caio Rocha <aiorocha@microsoft.com>
2025-02-19 09:48:31 -08:00

70 lines
2.5 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_allpairs(gpus, instances):
"""
AllReduce with all pairs algorithm using packets format.
Steps:
1. Each rank sends its nth chunk to the nth rank's scratch space.
2. Each rank performs a local reduction on its nth chunk using data from all other ranks' scratch spaces.
3. Each rank sends the reduced data to all other ranks' scratch spaces.
4. Each rank retrieves the final reduced result from the scratch space.
"""
size = gpus
chunksperloop = gpus * gpus
collective = AllReduce(size, chunksperloop, True)
with MSCCLPPProgram(
"allreduce_packets",
collective,
size,
instances,
protocol="LL",
use_double_scratch_buffer=True,
):
# Each rank sends the nth chunk to the nth rank into scratch space
for r1 in range(size):
for tb in range(size):
if tb == r1:
continue
remote_rank = tb
index = remote_rank * size
c = chunk(r1, Buffer.input, index, size)
c.put_packet(remote_rank, Buffer.scratch, index=r1 * size, sendtb=tb)
# Each rank performs a local reduction on the nth chunk
# Utilize 8 threadblocks for this reduction for better parallelism
for r in range(size):
for index in range(size):
c = chunk(r, Buffer.input, r * size + index)
for peer in range(size):
if peer != r:
c.reduce_packet(chunk(r, Buffer.scratch, peer * size + index), recvtb=index)
for peer in range(size):
if peer != r:
c.put_packet(peer, Buffer.scratch, (size * size) + r * size + index, sendtb=index)
# Each rank get final result from scratch space
for r in range(size):
for peer in range(size):
if peer != r:
c = chunk(r, Buffer.scratch, size * size + peer * size, size)
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
Json()
Check()
parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
args = parser.parse_args()
allreduce_allpairs(args.num_gpus, args.instances)