[Parallel State Refactor 1/n] Remove stream of PyNCCL (#20866)

This commit is contained in:
DarkSharpness
2026-04-03 00:47:50 +08:00
committed by GitHub
parent b21db86e2f
commit df94cdcebb
2 changed files with 34 additions and 86 deletions

View File

@@ -31,7 +31,6 @@ class PyNcclCommunicator:
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
use_current_stream: bool = False,
):
"""
Args:
@@ -62,7 +61,6 @@ class PyNcclCommunicator:
if self.world_size == 1:
self.available = False
self.disabled = True
self.stream = None
return
try:
self.nccl = NCCLLibrary(library_path)
@@ -71,12 +69,10 @@ class PyNcclCommunicator:
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
self.stream = None
return
self.available = True
self.disabled = False
self.use_current_stream = use_current_stream
self.nccl_version = self.nccl.ncclGetRawVersion()
if self.rank == 0:
@@ -113,12 +109,13 @@ class PyNcclCommunicator:
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank
)
self.stream = torch.cuda.Stream()
warmup_stream = torch.cuda.Stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
with torch.cuda.stream(warmup_stream):
data = torch.zeros(1, device=device)
self.all_reduce(data)
warmup_stream.synchronize()
del data
# by default it is disabled, e.g. in profiling models and prefill phase.
@@ -126,24 +123,11 @@ class PyNcclCommunicator:
# when we are using CUDA graph.
self.disabled = True
def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
"""Return the stream to use for NCCL calls.
def _resolve_stream(self) -> torch.cuda.Stream:
"""Return the current device stream used for NCCL calls."""
return get_current_device_stream_fast()
Behavior mirrors the previous inline logic:
- if an explicit stream is provided, return it
- if stream is None and self.use_current_stream is True, return
torch.cuda.current_stream()
- otherwise return the communicator's default stream (self.stream)
"""
if stream is not None:
return stream
if self.use_current_stream:
return get_current_device_stream_fast()
return self.stream
def all_reduce(
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
):
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
if self.disabled:
return
# nccl communicator created on a specific device
@@ -153,7 +137,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
self.nccl.ncclAllReduce(
buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()),
@@ -169,7 +153,6 @@ class PyNcclCommunicator:
in_tensor: torch.Tensor,
out_tensor: Optional[torch.Tensor] = None,
op: ReduceOp = ReduceOp.SUM,
stream=None,
) -> Optional[torch.Tensor]:
if self.disabled:
return None
@@ -181,7 +164,7 @@ class PyNcclCommunicator:
if out_tensor is None:
out_tensor = torch.empty_like(in_tensor)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
self.nccl.ncclAllReduce(
buffer_type(in_tensor.data_ptr()), # sendbuff
buffer_type(out_tensor.data_ptr()), # recvbuff - DIFFERENT pointer
@@ -197,7 +180,6 @@ class PyNcclCommunicator:
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
@@ -209,7 +191,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
if sizes is not None:
split_offset = 0
@@ -242,7 +224,7 @@ class PyNcclCommunicator:
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None,
stream: torch.cuda.Stream,
sizes: Optional[list[int]] = None,
):
"""
@@ -256,7 +238,6 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
stream = self._resolve_stream(stream)
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
@@ -271,7 +252,6 @@ class PyNcclCommunicator:
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
@@ -283,7 +263,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
if sizes is not None:
split_offset = 0
@@ -314,14 +294,14 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
def send(self, tensor: torch.Tensor, dst: int):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
@@ -331,14 +311,14 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def recv(self, tensor: torch.Tensor, src: int, stream=None):
def recv(self, tensor: torch.Tensor, src: int):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
@@ -348,14 +328,14 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
def broadcast(self, tensor: torch.Tensor, src: int):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
stream = self._resolve_stream(stream)
stream = self._resolve_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
@@ -387,25 +367,17 @@ class PyNcclCommunicator:
self.nccl.ncclGroupEnd()
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
):
def change_state(self, enable: Optional[bool] = None):
"""
A context manager to change the state of the communicator.
A context manager to change the enabled state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream
try:
yield
finally:
self.disabled = old_disable

View File

@@ -246,7 +246,6 @@ class GroupCoordinator:
use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
pynccl_use_current_stream: bool = False,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
):
# Set group info
@@ -316,7 +315,6 @@ class GroupCoordinator:
# Import communicators
self.use_pynccl = use_pynccl
self.pynccl_use_current_stream = pynccl_use_current_stream
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce
@@ -358,7 +356,6 @@ class GroupCoordinator:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
use_current_stream=pynccl_use_current_stream,
)
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -533,9 +530,7 @@ class GroupCoordinator:
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
)
maybe_pynccl_context = pynccl_comm.change_state(enable=True)
pymscclpp_comm = self.pymscclpp_comm
maybe_pymscclpp_context: Any
@@ -602,9 +597,7 @@ class GroupCoordinator:
return self.npu_communicator.all_reduce(input_)
if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
with self.pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with self.pynccl_comm.change_state(enable=True):
self.pynccl_comm.all_reduce(input_)
return input_
@@ -720,9 +713,7 @@ class GroupCoordinator:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
elif outplace_all_reduce_method == "pynccl":
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with pynccl_comm.change_state(enable=True):
out = pynccl_comm.outplace_all_reduce(input_)
assert out is not None
return out
@@ -746,9 +737,7 @@ class GroupCoordinator:
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with pynccl_comm.change_state(enable=True):
pynccl_comm.reduce_scatter(output, input)
else:
torch.distributed.reduce_scatter_tensor(
@@ -780,9 +769,7 @@ class GroupCoordinator:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with pynccl_comm.change_state(enable=True):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for reduce_scatterv"
@@ -811,9 +798,7 @@ class GroupCoordinator:
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with pynccl_comm.change_state(enable=True):
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
@@ -827,7 +812,7 @@ class GroupCoordinator:
reg_all_gather_into_tensor(output, input, group_name=self.unique_name)
def cp_all_gather_into_tensor_async(
self, output: torch.Tensor, input: torch.Tensor, stream=None
self, output: torch.Tensor, input: torch.Tensor, stream: torch.cuda.Stream
):
"""
Implement an asynchronous `allgather` operation on a specified stream.
@@ -835,9 +820,6 @@ class GroupCoordinator:
eliminating the CPU-side launch-kernel blocking issue caused by synchronization problems.
The specific implementation uses the interface provided by pynccl to remove the synchronization logic of events.
"""
assert (
stream is not None
), f"Invalid params stream ({stream}, Please specify the stream to use when calling cp_all_gather_into_tensor_async.)"
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
self.all_gather_into_tensor(output, input)
@@ -930,9 +912,7 @@ class GroupCoordinator:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
with pynccl_comm.change_state(enable=True):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"
@@ -1439,7 +1419,6 @@ def init_model_parallel_group(
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
pynccl_use_current_stream: bool = True,
use_torch_symm_mem_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
@@ -1465,7 +1444,6 @@ def init_model_parallel_group(
use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
pynccl_use_current_stream=pynccl_use_current_stream,
)
@@ -1835,7 +1813,6 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
)
if duplicate_tp_group:
@@ -1851,7 +1828,6 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="pdmux_prefill_tp",
pynccl_use_current_stream=True,
)
if _TP.pynccl_comm:
_TP.pynccl_comm.disabled = False