Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-03 21:21:25 +00:00) — 58 lines, 1.6 KiB, Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import argparse
|
|
from mscclpp.language import *
|
|
from mscclpp.language.collectives import SendRecv
|
|
from mscclpp.language.buffer import Buffer
|
|
from mscclpp.language.types import ChannelType
|
|
|
|
|
|
def send_recv(instances):
    """Build a two-rank send/recv program using LL packets over port channels.

    The program is generated inside an ``MSCCLPPProgram`` context configured
    with the ``"LL"`` protocol and ``use_double_scratch_buffer=True``:

    1. Every rank puts its input chunk 0, in packet form, into scratch slot 1
       of each peer over a port channel (``sendtb=0``).
    2. Every rank then copies the packet in its own scratch slot 1 into its
       output chunk 0.

    Every sender targets the same scratch slot (1), and only that slot is
    copied out. The layout is therefore only correct for exactly two ranks,
    so the rank count is fixed at 2.

    Args:
        instances: number of instances, forwarded to ``MSCCLPPProgram``.
    """
    num_ranks = 2
    chunks_per_loop = 1
    collective = SendRecv(num_ranks, chunks_per_loop, False)

    with MSCCLPPProgram(
        "send_recv",
        collective,
        num_ranks,
        instances,
        protocol="LL",
        use_double_scratch_buffer=True,
    ):
        # Phase 1: push input chunk 0 from each rank to every other rank.
        for src in range(num_ranks):
            peers = [dst for dst in range(num_ranks) if dst != src]
            for dst in peers:
                chunk(src, Buffer.input, 0).put_packet(
                    dst,
                    Buffer.scratch,
                    1,
                    sendtb=0,
                    chan_type=ChannelType.port,
                    # NOTE(review): temp_buffer/temp_buffer_index presumably pick the
                    # sender-side staging slot for packetization; confirm in mscclpp docs.
                    temp_buffer=Buffer.scratch,
                    temp_buffer_index=0,
                )

        # Phase 2: each rank unpacks the received packet from scratch slot 1.
        for rank in range(num_ranks):
            chunk(rank, Buffer.scratch, 1).copy_packet(rank, Buffer.output, 0, sendtb=0)

        # Serialize the program and run the DSL's Check() while the context is open.
        Json()
        Check()
|
|
|
|
|
|
if __name__ == "__main__":
    # Only parse the command line and generate the program when run as a script.
    # Importing this module then neither consumes sys.argv nor emits output.
    parser = argparse.ArgumentParser()
    parser.add_argument("instances", type=int, help="number of instances")

    args = parser.parse_args()

    send_recv(args.instances)
|