mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-07 00:05:19 +00:00
Updating MSCCLLang Examples (#462)
Co-authored-by: Caio Rocha <aiorocha@microsoft.com>
This commit is contained in:
@@ -35,7 +35,7 @@ def allreduce_allpairs(gpus, instances):
|
||||
remote_rank = tb
|
||||
index = remote_rank * size
|
||||
c = chunk(r1, Buffer.input, index, size)
|
||||
- c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb)
|
||||
+ c.put_packet(remote_rank, Buffer.scratch, index=r1 * size, sendtb=tb)
|
||||
|
||||
# Each rank performs a local reduction on the nth chunk
|
||||
# Utilize 8 threadblocks for this reduction for better parallelism
|
||||
@@ -44,16 +44,16 @@ def allreduce_allpairs(gpus, instances):
|
||||
c = chunk(r, Buffer.input, r * size + index)
|
||||
for peer in range(size):
|
||||
if peer != r:
|
||||
- c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
|
||||
+ c.reduce_packet(chunk(r, Buffer.scratch, peer * size + index), recvtb=index)
|
||||
for peer in range(size):
|
||||
if peer != r:
|
||||
- c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)
|
||||
+ c.put_packet(peer, Buffer.scratch, (size * size) + r * size + index, sendtb=index)
|
||||
|
||||
# Each rank get final result from scratch space
|
||||
for r in range(size):
|
||||
for peer in range(size):
|
||||
if peer != r:
|
||||
- c = chunk(r, "scratch", size * size + peer * size, size)
|
||||
+ c = chunk(r, Buffer.scratch, size * size + peer * size, size)
|
||||
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
|
||||
|
||||
Json()
|
||||
|
||||
@@ -33,16 +33,16 @@ def send_recv(instances):
|
||||
c = chunk(r, Buffer.input, 0)
|
||||
c.put_packet(
|
||||
nghr,
|
||||
- "scratch",
|
||||
+ Buffer.scratch,
|
||||
1,
|
||||
sendtb=0,
|
||||
chan_type=ChannelType.port,
|
||||
- temp_buffer="scratch",
|
||||
+ temp_buffer=Buffer.scratch,
|
||||
temp_buffer_index=0,
|
||||
)
|
||||
|
||||
for r in range(size):
|
||||
- c = chunk(r, "scratch", 1)
|
||||
+ c = chunk(r, Buffer.scratch, 1)
|
||||
c.copy_packet(r, Buffer.output, 0, sendtb=0)
|
||||
|
||||
Json()
|
||||
|
||||
@@ -31,16 +31,16 @@ def send_recv(instances):
|
||||
c = chunk(r, Buffer.input, 0)
|
||||
c.put(
|
||||
nghr,
|
||||
- "scratch",
|
||||
+ Buffer.scratch,
|
||||
1,
|
||||
sendtb=0,
|
||||
chan_type=ChannelType.port,
|
||||
)
|
||||
- c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
|
||||
- c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
|
||||
+ c.signal(nghr, Buffer.scratch, 1, sendtb=0, chan_type=ChannelType.port)
|
||||
+ c.flush(nghr, Buffer.scratch, 1, sendtb=0, chan_type=ChannelType.port)
|
||||
|
||||
for r in range(size):
|
||||
- c = chunk(r, "scratch", 1)
|
||||
+ c = chunk(r, Buffer.scratch, 1)
|
||||
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.port)
|
||||
c.copy(r, Buffer.output, 0, sendtb=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user