Support Fusion for ReadPutPacket Operation at DSL (#742)

Add support for fusing ReadPutPacket operations in the DSL, which
reduces the overhead of reading the same packet data from the scratch
buffer multiple times. Fusion occurs when two rppkt operations with the
same src_buffer are executed consecutively:

rppkt(src, dst0) + rppkt(src, dst1) -> rppkt(src, [dst0, dst1])

Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
Caio Rocha
2026-02-12 17:27:20 -08:00
committed by GitHub
parent 42be3660e0
commit dff3bc7bbb
4 changed files with 105 additions and 9 deletions

View File

@@ -534,6 +534,7 @@ class PutOperation(BaseOperation):
self.dst_buff = dst_buff
self.channel_ids = channel_ids
self.channel_type = channel_type
self.from_packet = from_packet
self.to_packet = to_packet
self.with_signal = with_signal
self.with_signal_and_flush = with_signal_and_flush
@@ -579,6 +580,25 @@ class PutOperation(BaseOperation):
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
)
elif (
isinstance(other, PutOperation)
and self.name == Instruction.read_put_packet
and self.name == other.name
and self.src_buff == other.src_buff
and self.channel_type == other.channel_type
and self.tbg_info == other.tbg_info
):
fused_operation = PutOperation(
src_buff=self.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
tbg_info=self.tbg_info,
from_packet=self.from_packet,
to_packet=self.to_packet,
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
)
return fused_operation

View File

@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *
def allgather_example(name, gpu_size, num_threads_per_block, min_message_size, max_message_size):
    """Build an all-gather program with the "LL" protocol and print it as JSON.

    The program is assembled in three phases, each iterating over every rank:

    1. Each rank copies ``input_buffer[0:1]`` as packets into its own scratch
       buffer at index ``gpu_size + rank``.
    2. Each rank issues one ``read_put_packets`` per remote peer, reading from
       that same local scratch slot and writing to index ``rank`` of the
       peer's scratch buffer.
    3. Each rank unpacks, for every remote peer, scratch slot ``peer`` into
       ``output_buffer[peer]``.

    Args:
        name: Program name, forwarded to ``CollectiveProgram``.
        gpu_size: Number of ranks; also determines the scratch buffer size
            (``2 * gpu_size``) and the staging offset (``gpu_size``).
        num_threads_per_block: Forwarded to ``CollectiveProgram``.
        min_message_size: Forwarded to ``CollectiveProgram``.
        max_message_size: Forwarded to ``CollectiveProgram``.
    """
    chunksperloop = 1
    collective = AllGather(gpu_size, chunksperloop, True)

    # Local staging slots live in the upper half of each scratch buffer
    # ([gpu_size, 2 * gpu_size)); the lower half is written by remote peers.
    scratch_offset = gpu_size
    # Every operation is issued with the same `tb` value (presumably the
    # thread block id -- confirm against the DSL), so the per-peer
    # read_put_packets calls of a rank are consecutive and share one source.
    tb = 0

    with CollectiveProgram(
        name,
        collective,
        gpu_size,
        protocol="LL",
        num_threads_per_block=num_threads_per_block,
        use_double_scratch_buffer=True,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
    ):
        # Creating Scratch Buffers
        scratch_buffer = [Buffer(gpu, 2 * gpu_size) for gpu in range(gpu_size)]

        # Copying it to scratch buffer
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            input_buffer = rank.get_input_buffer()
            staging_slot = scratch_offset + gpu
            rank.copy_packets(scratch_buffer[gpu][staging_slot : staging_slot + 1], input_buffer[0:1], tb=tb)

        # Putting packets in the remote scratch buffer
        for gpu in range(gpu_size):
            staging_slot = scratch_offset + gpu
            for peer in range(1, gpu_size):
                dst_rank = (gpu + peer) % gpu_size
                ch = MemoryChannel(dst_rank, gpu)
                ch.read_put_packets(
                    scratch_buffer[dst_rank][gpu : gpu + 1],
                    scratch_buffer[gpu][staging_slot : staging_slot + 1],
                    tb,
                )

        # Copying packets from local scratch buffer to local buffer
        for gpu in range(gpu_size):
            rank = Rank(gpu)
            output_buffer = rank.get_output_buffer()
            for peer in range(1, gpu_size):
                # Slot `src_rank` of the local scratch buffer was filled by
                # rank `src_rank` in the previous phase.
                src_rank = (gpu + peer) % gpu_size
                rank.unpack_packets(
                    output_buffer[src_rank : src_rank + 1],
                    scratch_buffer[gpu][src_rank : src_rank + 1],
                    tb=tb,
                )

        print(JSON())
# Command-line entry point: collect the program parameters and emit the plan.
parser = argparse.ArgumentParser()
for _flag, _spec in (
    ("--name", {"type": str, "help": "name of the program"}),
    ("--num_gpus", {"type": int, "help": "number of gpus"}),
    ("--num_threads_per_block", {"type": int, "default": 1024, "help": "number of threads per block"}),
    ("--min_message_size", {"type": int, "default": 0, "help": "minimum message size"}),
    ("--max_message_size", {"type": int, "default": 2**64 - 1, "help": "maximum message size"}),
):
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

allgather_example(
    name=args.name,
    gpu_size=args.num_gpus,
    num_threads_per_block=args.num_threads_per_block,
    min_message_size=args.min_message_size,
    max_message_size=args.max_message_size,
)