diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 660582ad3..eccbc872e 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -31,7 +31,6 @@ class PyNcclCommunicator: group: Union[ProcessGroup, StatelessProcessGroup], device: Union[int, str, torch.device], library_path: Optional[str] = None, - use_current_stream: bool = False, ): """ Args: @@ -62,7 +61,6 @@ class PyNcclCommunicator: if self.world_size == 1: self.available = False self.disabled = True - self.stream = None return try: self.nccl = NCCLLibrary(library_path) @@ -71,12 +69,10 @@ class PyNcclCommunicator: # e.g. in a non-GPU environment self.available = False self.disabled = True - self.stream = None return self.available = True self.disabled = False - self.use_current_stream = use_current_stream self.nccl_version = self.nccl.ncclGetRawVersion() if self.rank == 0: @@ -113,12 +109,13 @@ class PyNcclCommunicator: self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank ) - self.stream = torch.cuda.Stream() + warmup_stream = torch.cuda.Stream() # A small all_reduce for warmup. - data = torch.zeros(1, device=device) - self.all_reduce(data) - self.stream.synchronize() + with torch.cuda.stream(warmup_stream): + data = torch.zeros(1, device=device) + self.all_reduce(data) + warmup_stream.synchronize() del data # by default it is disabled, e.g. in profiling models and prefill phase. @@ -126,24 +123,11 @@ class PyNcclCommunicator: # when we are using CUDA graph. self.disabled = True - def _resolve_stream(self, stream: Optional[torch.cuda.Stream]): - """Return the stream to use for NCCL calls. + def _resolve_stream(self) -> torch.cuda.Stream: + """Return the current device stream used for NCCL calls.""" + return get_current_device_stream_fast() - Behavior mirrors the previous inline logic: - - if an explicit stream is provided, return it - - if stream is None and self.use_current_stream is True, return - torch.cuda.current_stream() - - otherwise return the communicator's default stream (self.stream) - """ - if stream is not None: - return stream - if self.use_current_stream: - return get_current_device_stream_fast() - return self.stream - - def all_reduce( - self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None - ): + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM): if self.disabled: return # nccl communicator created on a specific device @@ -153,7 +137,7 @@ class PyNcclCommunicator: f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclAllReduce( buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()), @@ -169,7 +153,6 @@ class PyNcclCommunicator: in_tensor: torch.Tensor, out_tensor: Optional[torch.Tensor] = None, op: ReduceOp = ReduceOp.SUM, - stream=None, ) -> Optional[torch.Tensor]: if self.disabled: return None @@ -181,7 +164,7 @@ class PyNcclCommunicator: if out_tensor is None: out_tensor = torch.empty_like(in_tensor) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclAllReduce( buffer_type(in_tensor.data_ptr()), # sendbuff buffer_type(out_tensor.data_ptr()), # recvbuff - DIFFERENT pointer @@ -197,7 +180,6 @@ class PyNcclCommunicator: self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None, sizes: Optional[list[int]] = None, ): if self.disabled: @@ -209,7 +191,7 @@ class PyNcclCommunicator: f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() if sizes is not None: split_offset = 0 @@ -242,7 +224,7 @@ class PyNcclCommunicator: self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None, + stream: torch.cuda.Stream, sizes: Optional[list[int]] = None, ): """ @@ -256,7 +238,6 @@ class PyNcclCommunicator: f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream(stream) self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), @@ -271,7 +252,6 @@ class PyNcclCommunicator: output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None, sizes: Optional[list[int]] = None, ): if self.disabled: @@ -283,7 +263,7 @@ class PyNcclCommunicator: f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() if sizes is not None: split_offset = 0 @@ -314,14 +294,14 @@ class PyNcclCommunicator: cudaStream_t(stream.cuda_stream), ) - def send(self, tensor: torch.Tensor, dst: int, stream=None): + def send(self, tensor: torch.Tensor, dst: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), @@ -331,14 +311,14 @@ class PyNcclCommunicator: cudaStream_t(stream.cuda_stream), ) - def recv(self, tensor: torch.Tensor, src: int, stream=None): + def recv(self, tensor: torch.Tensor, src: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), @@ -348,14 +328,14 @@ class PyNcclCommunicator: cudaStream_t(stream.cuda_stream), ) - def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + def broadcast(self, tensor: torch.Tensor, src: int): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) - stream = self._resolve_stream(stream) + stream = self._resolve_stream() if src == self.rank: sendbuff = buffer_type(tensor.data_ptr()) @@ -387,25 +367,17 @@ class PyNcclCommunicator: self.nccl.ncclGroupEnd() @contextmanager - def change_state( - self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None - ): + def change_state(self, enable: Optional[bool] = None): """ - A context manager to change the state of the communicator. + A context manager to change the enabled state of the communicator. """ if enable is None: # guess a default value when not specified enable = self.available - if stream is None: - stream = self.stream - old_disable = self.disabled - old_stream = self.stream - - self.stream = stream self.disabled = not enable - yield - - self.disabled = old_disable - self.stream = old_stream + try: + yield + finally: + self.disabled = old_disable diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 72311f9d3..abe3c218a 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -246,7 +246,6 @@ class GroupCoordinator: use_npu_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, - pynccl_use_current_stream: bool = False, gloo_timeout: timedelta = timedelta(seconds=120 * 60), ): # Set group info @@ -316,7 +315,6 @@ class GroupCoordinator: # Import communicators self.use_pynccl = use_pynccl - self.pynccl_use_current_stream = pynccl_use_current_stream self.use_pymscclpp = use_pymscclpp self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce @@ -358,7 +356,6 @@ class GroupCoordinator: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, - use_current_stream=pynccl_use_current_stream, ) self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None @@ -533,9 +530,7 @@ class GroupCoordinator: if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ) + maybe_pynccl_context = pynccl_comm.change_state(enable=True) pymscclpp_comm = self.pymscclpp_comm maybe_pymscclpp_context: Any @@ -602,9 +597,7 @@ class GroupCoordinator: return self.npu_communicator.all_reduce(input_) if self.pynccl_comm is not None and self.is_symmetric_memory_enabled(): - with self.pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with self.pynccl_comm.change_state(enable=True): self.pynccl_comm.all_reduce(input_) return input_ @@ -720,9 +713,7 @@ class GroupCoordinator: assert not pymscclpp_comm.disabled out = pymscclpp_comm.all_reduce(input_) elif outplace_all_reduce_method == "pynccl": - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): out = pynccl_comm.outplace_all_reduce(input_) assert out is not None return out @@ -746,9 +737,7 @@ class GroupCoordinator: if pynccl_comm is not None and ( not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): pynccl_comm.reduce_scatter(output, input) else: torch.distributed.reduce_scatter_tensor( @@ -780,9 +769,7 @@ class GroupCoordinator: world_size = self.world_size pynccl_comm = self.pynccl_comm - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): assert ( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for reduce_scatterv" @@ -811,9 +798,7 @@ class GroupCoordinator: if pynccl_comm is not None and ( not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): pynccl_comm.all_gather(output, input) else: torch.distributed.all_gather_into_tensor( @@ -827,7 +812,7 @@ class GroupCoordinator: reg_all_gather_into_tensor(output, input, group_name=self.unique_name) def cp_all_gather_into_tensor_async( - self, output: torch.Tensor, input: torch.Tensor, stream=None + self, output: torch.Tensor, input: torch.Tensor, stream: torch.cuda.Stream ): """ Implement an asynchronous `allgather` operation on a specified stream. @@ -835,9 +820,6 @@ class GroupCoordinator: eliminating the CPU-side launch-kernel blocking issue caused by synchronization problems. The specific implementation uses the interface provided by pynccl to remove the synchronization logic of events. """ - assert ( - stream is not None - ), f"Invalid params stream ({stream}, Please specify the stream to use when calling cp_all_gather_into_tensor_async.)" pynccl_comm = self.pynccl_comm if pynccl_comm is None or pynccl_comm.disabled: self.all_gather_into_tensor(output, input) @@ -930,9 +912,7 @@ class GroupCoordinator: world_size = self.world_size pynccl_comm = self.pynccl_comm - with pynccl_comm.change_state( - enable=True, stream=get_current_device_stream_fast() - ): + with pynccl_comm.change_state(enable=True): assert ( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for all_gatherv" @@ -1439,7 +1419,6 @@ def init_model_parallel_group( use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, use_mscclpp_allreduce: Optional[bool] = None, - pynccl_use_current_stream: bool = True, use_torch_symm_mem_allreduce: Optional[bool] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: @@ -1465,7 +1444,6 @@ def init_model_parallel_group( use_npu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, - pynccl_use_current_stream=pynccl_use_current_stream, ) @@ -1835,7 +1813,6 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="tp", - pynccl_use_current_stream=duplicate_tp_group, ) if duplicate_tp_group: @@ -1851,7 +1828,6 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="pdmux_prefill_tp", - pynccl_use_current_stream=True, ) if _TP.pynccl_comm: _TP.pynccl_comm.disabled = False