Merge mscclpp-lang to mscclpp project (#442)

This is the first step in merging msccl-tools into the mscclpp repo. In this
step, we move all msccl-related code, make the current tests pass, and do
some necessary refactoring.

Add `mscclpp.language` module
Add `_InstructionOptimizer` and `DagOptimizer` classes to optimize the DAG
Add `DagLower` to lower the DAG to an intermediate representation
Add documents for mscclpp.language
Remove msccl related code
This commit is contained in:
Binyang Li
2025-01-22 09:47:37 -08:00
committed by GitHub
parent 4ee15b7ad0
commit af0bb86e07
28 changed files with 3417 additions and 18 deletions

View File

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language import *
from mscclpp.language.collectives import AllReduce
from mscclpp.language.buffer import Buffer
def allreduce_ring(size, instances):
    """
    Implements a ring based allreduce over ``size`` ranks.

    Every rank's input buffer is addressed as ``size`` chunks. The schedule
    is declared in two phases, each running for ``size - 1`` steps:

    1. Reduce ring: each rank signals its next rank, waits for the signal
       from its previous rank (so the data is known to be ready there), and
       then reduces its chunk with the previous rank's copy of that chunk.
    2. Propagate ring: after all the data is reduced, each rank puts a chunk
       into the next rank's input buffer, signals it, and waits for the
       matching signal from its previous rank.

    ``Json()`` and ``Check()`` are then called inside the program context
    (presumably to emit the generated program and validate it; confirm
    against the mscclpp.language documentation).

    Args:
        size: Number of ranks in the ring. It is also passed twice to
            ``AllReduce`` and once to ``MSCCLPPProgram``.
        instances: Forwarded unchanged to ``MSCCLPPProgram``.
    """
    # NOTE(review): arguments are presumably (num_ranks, chunk_factor, inplace);
    # confirm against the AllReduce collective definition.
    collective = AllReduce(size, size, True)
    with MSCCLPPProgram(
        "allreduce_ring",
        collective,
        size,
        instances,
        protocol="Simple",
    ):
        # Reduce ring
        for step in range(size - 1):
            for index in range(size):
                # Python's % always yields a non-negative result for a
                # positive modulus, so the "- 1" terms wrap around the ring.
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                # Index of the chunk handled by the previous rank in this step.
                prev_index = (index + size - 1) % size

                # Tell the next rank that chunk `index` is ready on this rank.
                c = chunk(rank, Buffer.input, index)
                c.signal(next_rank, Buffer.input, index, 0)

                # Wait for the previous rank's signal before reading its chunk.
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)
                c.reduce(chunk(prev_rank, Buffer.input, prev_index), recvtb=0)

        # Propagate ring
        # Steps start at -1 so that the first sender of chunk `index` is
        # rank (index - 1) % size.
        for step in range(-1, size - 2):
            for index in range(size):
                rank = (index + step) % size
                next_rank = (index + step + 1) % size
                prev_rank = (index + step - 1) % size
                prev_index = (index + size - 1) % size

                # Copy chunk `index` to the next rank, then signal it.
                c = chunk(rank, Buffer.input, index)
                c.put(next_rank, Buffer.input, index, sendtb=0)
                c.signal(next_rank, Buffer.input, index, 0)

                # Wait for the matching signal from the previous rank.
                c = chunk(rank, Buffer.input, prev_index)
                c.wait(prev_rank, Buffer.input, prev_index, 0)

        Json()
        Check()
if __name__ == "__main__":
    # Parse the command line and generate the program only when executed as a
    # script, so importing this module (e.g. to reuse allreduce_ring) does not
    # consume sys.argv or trigger program generation as a side effect.
    parser = argparse.ArgumentParser()
    parser.add_argument("num_gpus", type=int, help="number of gpus")
    parser.add_argument("instances", type=int, help="number of instances")
    args = parser.parse_args()
    allreduce_ring(args.num_gpus, args.instances)