mscclpp/test/torch/allreduce_temp_buff.py
Binyang Li bd68319e3e Refactor algo selection logic and introduce symmetric_memory env (#741)
This PR refactors the algorithm selection logic in MSCCL++ and
introduces support for symmetric memory configuration through
environment variables.


1. Algorithm Selection Refactoring
Use a separate class for algorithm selection. This makes it possible to introduce more complex selection logic based on message size, architecture, whether a CUDA graph is enabled, and the memory allocation method, as sketched below.
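As a rough illustration of where such a selector could go (not the actual MSCCL++ implementation), here is a minimal Python sketch; the class, field, and algorithm names (SelectionContext, AlgoSelector, allreduce_small, allreduce_large) are hypothetical:

# Hypothetical sketch of a dedicated algorithm-selection class. Names and
# size thresholds are illustrative only, not taken from MSCCL++.
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class SelectionContext:
    message_size: int       # bytes
    arch: str               # e.g. "sm90"
    in_cuda_graph: bool     # whether the call happens under graph capture/replay
    symmetric_memory: bool  # how the buffers were allocated


class AlgoSelector:
    def __init__(self, algorithms: Dict[str, Callable]):
        self.algorithms = algorithms

    def select(self, ctx: SelectionContext) -> Callable:
        # Simple size-based policy; a real selector could also branch on
        # arch, graph capture, and the memory allocation method.
        if ctx.symmetric_memory and ctx.message_size <= (1 << 20):
            return self.algorithms["allreduce_small_symmetric"]
        if ctx.message_size <= (1 << 20):
            return self.algorithms["allreduce_small"]
        return self.algorithms["allreduce_large"]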

2. Symmetric Memory Support
Introduced a symmetricMemory parameter in algorithm context key generation. Removed the disableChannelCache env, as it was ambiguous.
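A hedged Python sketch of the idea, not the actual MSCCL++ code: the function name and key layout below are hypothetical, and only the symmetric-memory flag itself comes from this PR:

# Hypothetical sketch: fold the symmetric-memory setting into the cache key so
# contexts built for symmetric and non-symmetric allocations never collide.
def make_context_key(collective: str, message_size: int, world_size: int,
                     symmetric_memory: bool) -> tuple:
    # symmetric_memory would be derived from MSCCLPP_NCCL_SYMMETRIC_MEMORY
    # (see the run command in the script below).
    return (collective, message_size, world_size, symmetric_memory)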

3. Add new args for build_default_algorithms 
Add flag_buffer and flag_buffer_size args to build_default_algorithms. This lets different algorithms share a unified flag buffer, avoiding application hangs when switching algorithms across message sizes (toy sketch below).
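A toy, self-contained Python illustration of the shared flag buffer; ToyAlgo and the buffer size are made up, and only the flag_buffer/flag_buffer_size argument names come from this PR:

# Toy illustration, not the MSCCL++ API: two algorithm objects are built
# against the same flag buffer, so switching between them for different
# message sizes keeps every rank synchronizing on the same flags.
import torch


class ToyAlgo:
    def __init__(self, name: str, flag_buffer: torch.Tensor, flag_buffer_size: int):
        self.name = name
        self.flags = flag_buffer            # shared synchronization flags
        self.num_flags = flag_buffer_size


shared_flags = torch.zeros(4096, dtype=torch.int32)  # size chosen arbitrarily
small_msg = ToyAlgo("allreduce_small", shared_flags, shared_flags.numel())
large_msg = ToyAlgo("allreduce_large", shared_flags, shared_flags.numel())
assert small_msg.flags.data_ptr() == large_msg.flags.data_ptr()  # one buffer shared by both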

---------

Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
2026-02-12 19:06:18 -08:00


# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# run with:
# LD_PRELOAD=<MSCCLPP_REPO>/build/lib/libmscclpp_nccl.so MSCCLPP_NCCL_SYMMETRIC_MEMORY=false torchrun --nproc_per_node=8 ./allreduce_temp_buff.py
import os
import torch
import torch.nn as nn
import torch.distributed as dist


def init_dist():
    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    local = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local)
    dist.init_process_group("nccl")
    return rank, world, local


class SimpleModel(nn.Module):
    def __init__(self, DIN, DH, DOUT):
        super().__init__()
        self.layer1 = nn.Linear(DIN, DH, bias=False)
        self.layer2 = nn.Linear(DH, DOUT, bias=False)
        self.rank = int(os.environ["RANK"])
        self.eval()

    @torch.no_grad()
    def forward(self, x_bf16: torch.Tensor, out_bf16: torch.Tensor):
        """
        x_bf16: [B, DIN] (bf16) input
        out_bf16: [B, DOUT] (bf16) output buffer
        Returns:
            out_bf16: [B, DOUT] (bf16) output
        """
        out = self.layer1(x_bf16)
        temp = torch.empty_like(out, dtype=torch.bfloat16)
        temp.copy_(out)
        dist.all_reduce(temp, op=dist.ReduceOp.SUM)
        temp2 = temp
        if self.rank == 0:
            # On rank 0 only, switch to a different temp buffer so that MSCCL++ has to
            # handle a buffer address change between the two all_reduce calls
            temp2 = torch.empty_like(temp, dtype=torch.bfloat16)
            temp2.copy_(temp)
        # All ranks participate in this collective; only rank 0 passes a new buffer
        dist.all_reduce(temp2, op=dist.ReduceOp.SUM)
        out = self.layer2(temp2)
        out_bf16.copy_(out)
        return out


def main():
    rank, _, local = init_dist()
    device = torch.device("cuda", local)
    torch.set_grad_enabled(False)

    # message size B * DH * dtype_size = 32MB
    B, DIN, DH, DOUT = 2048, 1024, 8192, 8
    dtype = torch.bfloat16

    # Warm up comms
    dist.all_reduce(torch.ones(1, device=device).to(dtype))

    # Build model
    model = SimpleModel(DIN, DH, DOUT).to(device).to(dtype)

    # Static I/O buffers for capture (stable addresses)
    x_bf16 = torch.empty(B, DIN, dtype=dtype, device=device)
    out_bf16 = torch.empty(B, DOUT, dtype=dtype, device=device)

    # Eager warmup
    x_bf16.normal_()
    _ = model(x_bf16, out_bf16)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        model(x_bf16, out_bf16)

    for step in range(5):
        x_bf16.normal_()
        g.replay()
        if rank == 0:
            print(f"[step {step}] out_mean={out_bf16.float().mean().item():.6f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()