mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Support Fusion for ReadPutPacket Operation at DSL (#742)
Support is being added for fusing the ReadPutPacket operation on DSL, which reduces the overhead caused by reading packet data multiple times in the scratch buffer. Fusion occurs when two rppkt operations are executed consecutively with the same src_buffer: rppkt(src, dst0) + rppkt(src, dst1) -> rppkt(src, [dst0, dst1]) Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
@@ -534,6 +534,7 @@ class PutOperation(BaseOperation):
|
||||
self.dst_buff = dst_buff
|
||||
self.channel_ids = channel_ids
|
||||
self.channel_type = channel_type
|
||||
self.from_packet = from_packet
|
||||
self.to_packet = to_packet
|
||||
self.with_signal = with_signal
|
||||
self.with_signal_and_flush = with_signal_and_flush
|
||||
@@ -579,6 +580,25 @@ class PutOperation(BaseOperation):
|
||||
with_signal=self.with_signal,
|
||||
with_signal_and_flush=self.with_signal_and_flush,
|
||||
)
|
||||
elif (
|
||||
isinstance(other, PutOperation)
|
||||
and self.name == Instruction.read_put_packet
|
||||
and self.name == other.name
|
||||
and self.src_buff == other.src_buff
|
||||
and self.channel_type == other.channel_type
|
||||
and self.tbg_info == other.tbg_info
|
||||
):
|
||||
fused_operation = PutOperation(
|
||||
src_buff=self.src_buff,
|
||||
dst_buff=self.dst_buff + other.dst_buff,
|
||||
channel_ids=self.channel_ids + other.channel_ids,
|
||||
channel_type=self.channel_type,
|
||||
tbg_info=self.tbg_info,
|
||||
from_packet=self.from_packet,
|
||||
to_packet=self.to_packet,
|
||||
with_signal=self.with_signal,
|
||||
with_signal_and_flush=self.with_signal_and_flush,
|
||||
)
|
||||
|
||||
return fused_operation
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import *
|
||||
from mscclpp.language.general import *
|
||||
from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
    """Generate an LL-protocol allgather program exercising read_put_packets fusion.

    Each GPU stages its input chunk as packets in its own scratch buffer,
    read-puts that packet into every peer's scratch buffer, then unpacks the
    packets received from all peers into its output buffer. Emits the
    resulting program as JSON on stdout.

    Args:
        name: Name of the generated program.
        gpu_size: Number of GPUs participating in the allgather.
        num_threads_per_block: Threads per thread block for the generated kernels.
        min_message_size: Minimum message size the program is valid for.
        max_message_size: Maximum message size the program is valid for.
    """
    chunksperloop = 1
    collective = AllGather(gpu_size, chunksperloop, True)
    with CollectiveProgram(
        name,
        collective,
        gpu_size,
        protocol="LL",
        num_threads_per_block=num_threads_per_block,
        use_double_scratch_buffer=True,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
    ):
        # Scratch buffers: first gpu_size chunks receive peers' packets, the
        # second gpu_size chunks (at scratch_offset) stage this rank's own packets.
        scratch_offset = gpu_size
        scratch_buffer = [Buffer(gpu, 2 * gpu_size) for gpu in range(gpu_size)]

        # Stage the local input chunk as packets in the local scratch buffer.
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            input_buffer = rank.get_input_buffer()
            rank.copy_packets(
                scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1], input_buffer[0:1], tb=0
            )

        # Read-put the staged packet into every remote scratch buffer. All
        # puts on one rank share the same source chunk and thread block, which
        # is the pattern the DSL fuses into a single multi-destination rppkt.
        for gpu in range(gpu_size):
            for peer in range(1, gpu_size):
                dst_rank = (gpu + peer) % gpu_size
                ch = MemoryChannel(dst_rank, gpu)
                tb = 0
                ch.read_put_packets(
                    scratch_buffer[dst_rank][gpu : gpu + 1],
                    scratch_buffer[gpu][scratch_offset + gpu : scratch_offset + gpu + 1],
                    tb,
                )

        # Unpack the packets received from each peer into the output buffer.
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            output_buffer = rank.get_output_buffer()
            for peer in range(1, gpu_size):
                dst_rank = (gpu + peer) % gpu_size
                rank.unpack_packets(
                    output_buffer[dst_rank : dst_rank + 1],
                    scratch_buffer[gpu][dst_rank : dst_rank + 1],
                    tb=0,
                )

    print(JSON())
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--num_gpus", type=int, help="number of gpus")
|
||||
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
|
||||
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
|
||||
parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
allgather_example(args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size)
|
||||
Reference in New Issue
Block a user