mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-01 20:27:57 +00:00
[Parallel State Refactor 1/n] Remove stream of PyNCCL (#20866)
This commit is contained in:
@@ -31,7 +31,6 @@ class PyNcclCommunicator:
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
use_current_stream: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -62,7 +61,6 @@ class PyNcclCommunicator:
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
@@ -71,12 +69,10 @@ class PyNcclCommunicator:
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
self.use_current_stream = use_current_stream
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
if self.rank == 0:
|
||||
@@ -113,12 +109,13 @@ class PyNcclCommunicator:
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank
|
||||
)
|
||||
self.stream = torch.cuda.Stream()
|
||||
warmup_stream = torch.cuda.Stream()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
self.stream.synchronize()
|
||||
with torch.cuda.stream(warmup_stream):
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
warmup_stream.synchronize()
|
||||
del data
|
||||
|
||||
# by default it is disabled, e.g. in profiling models and prefill phase.
|
||||
@@ -126,24 +123,11 @@ class PyNcclCommunicator:
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
|
||||
"""Return the stream to use for NCCL calls.
|
||||
def _resolve_stream(self) -> torch.cuda.Stream:
|
||||
"""Return the current device stream used for NCCL calls."""
|
||||
return get_current_device_stream_fast()
|
||||
|
||||
Behavior mirrors the previous inline logic:
|
||||
- if an explicit stream is provided, return it
|
||||
- if stream is None and self.use_current_stream is True, return
|
||||
torch.cuda.current_stream()
|
||||
- otherwise return the communicator's default stream (self.stream)
|
||||
"""
|
||||
if stream is not None:
|
||||
return stream
|
||||
if self.use_current_stream:
|
||||
return get_current_device_stream_fast()
|
||||
return self.stream
|
||||
|
||||
def all_reduce(
|
||||
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
|
||||
):
|
||||
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
@@ -153,7 +137,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()),
|
||||
@@ -169,7 +153,6 @@ class PyNcclCommunicator:
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: Optional[torch.Tensor] = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.disabled:
|
||||
return None
|
||||
@@ -181,7 +164,7 @@ class PyNcclCommunicator:
|
||||
if out_tensor is None:
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()), # sendbuff
|
||||
buffer_type(out_tensor.data_ptr()), # recvbuff - DIFFERENT pointer
|
||||
@@ -197,7 +180,6 @@ class PyNcclCommunicator:
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
@@ -209,7 +191,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
@@ -242,7 +224,7 @@ class PyNcclCommunicator:
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None,
|
||||
stream: torch.cuda.Stream,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -256,7 +238,6 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
@@ -271,7 +252,6 @@ class PyNcclCommunicator:
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
@@ -283,7 +263,7 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
@@ -314,14 +294,14 @@ class PyNcclCommunicator:
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
def send(self, tensor: torch.Tensor, dst: int):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
@@ -331,14 +311,14 @@ class PyNcclCommunicator:
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
def recv(self, tensor: torch.Tensor, src: int):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
@@ -348,14 +328,14 @@ class PyNcclCommunicator:
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
def broadcast(self, tensor: torch.Tensor, src: int):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
stream = self._resolve_stream(stream)
|
||||
stream = self._resolve_stream()
|
||||
|
||||
if src == self.rank:
|
||||
sendbuff = buffer_type(tensor.data_ptr())
|
||||
@@ -387,25 +367,17 @@ class PyNcclCommunicator:
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||
):
|
||||
def change_state(self, enable: Optional[bool] = None):
|
||||
"""
|
||||
A context manager to change the state of the communicator.
|
||||
A context manager to change the enabled state of the communicator.
|
||||
"""
|
||||
if enable is None:
|
||||
# guess a default value when not specified
|
||||
enable = self.available
|
||||
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
|
||||
old_disable = self.disabled
|
||||
old_stream = self.stream
|
||||
|
||||
self.stream = stream
|
||||
self.disabled = not enable
|
||||
yield
|
||||
|
||||
self.disabled = old_disable
|
||||
self.stream = old_stream
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.disabled = old_disable
|
||||
|
||||
@@ -246,7 +246,6 @@ class GroupCoordinator:
|
||||
use_npu_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
pynccl_use_current_stream: bool = False,
|
||||
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
|
||||
):
|
||||
# Set group info
|
||||
@@ -316,7 +315,6 @@ class GroupCoordinator:
|
||||
|
||||
# Import communicators
|
||||
self.use_pynccl = use_pynccl
|
||||
self.pynccl_use_current_stream = pynccl_use_current_stream
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce
|
||||
@@ -358,7 +356,6 @@ class GroupCoordinator:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
use_current_stream=pynccl_use_current_stream,
|
||||
)
|
||||
|
||||
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
|
||||
@@ -533,9 +530,7 @@ class GroupCoordinator:
|
||||
if not pynccl_comm:
|
||||
maybe_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
)
|
||||
maybe_pynccl_context = pynccl_comm.change_state(enable=True)
|
||||
|
||||
pymscclpp_comm = self.pymscclpp_comm
|
||||
maybe_pymscclpp_context: Any
|
||||
@@ -602,9 +597,7 @@ class GroupCoordinator:
|
||||
return self.npu_communicator.all_reduce(input_)
|
||||
|
||||
if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
|
||||
with self.pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with self.pynccl_comm.change_state(enable=True):
|
||||
self.pynccl_comm.all_reduce(input_)
|
||||
return input_
|
||||
|
||||
@@ -720,9 +713,7 @@ class GroupCoordinator:
|
||||
assert not pymscclpp_comm.disabled
|
||||
out = pymscclpp_comm.all_reduce(input_)
|
||||
elif outplace_all_reduce_method == "pynccl":
|
||||
with pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
out = pynccl_comm.outplace_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
@@ -746,9 +737,7 @@ class GroupCoordinator:
|
||||
if pynccl_comm is not None and (
|
||||
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
|
||||
):
|
||||
with pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
pynccl_comm.reduce_scatter(output, input)
|
||||
else:
|
||||
torch.distributed.reduce_scatter_tensor(
|
||||
@@ -780,9 +769,7 @@ class GroupCoordinator:
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for reduce_scatterv"
|
||||
@@ -811,9 +798,7 @@ class GroupCoordinator:
|
||||
if pynccl_comm is not None and (
|
||||
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
|
||||
):
|
||||
with pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
pynccl_comm.all_gather(output, input)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
@@ -827,7 +812,7 @@ class GroupCoordinator:
|
||||
reg_all_gather_into_tensor(output, input, group_name=self.unique_name)
|
||||
|
||||
def cp_all_gather_into_tensor_async(
|
||||
self, output: torch.Tensor, input: torch.Tensor, stream=None
|
||||
self, output: torch.Tensor, input: torch.Tensor, stream: torch.cuda.Stream
|
||||
):
|
||||
"""
|
||||
Implement an asynchronous `allgather` operation on a specified stream.
|
||||
@@ -835,9 +820,6 @@ class GroupCoordinator:
|
||||
eliminating the CPU-side launch-kernel blocking issue caused by synchronization problems.
|
||||
The specific implementation uses the interface provided by pynccl to remove the synchronization logic of events.
|
||||
"""
|
||||
assert (
|
||||
stream is not None
|
||||
), f"Invalid params stream ({stream}, Please specify the stream to use when calling cp_all_gather_into_tensor_async.)"
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
self.all_gather_into_tensor(output, input)
|
||||
@@ -930,9 +912,7 @@ class GroupCoordinator:
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(
|
||||
enable=True, stream=get_current_device_stream_fast()
|
||||
):
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for all_gatherv"
|
||||
@@ -1439,7 +1419,6 @@ def init_model_parallel_group(
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
pynccl_use_current_stream: bool = True,
|
||||
use_torch_symm_mem_allreduce: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
if use_custom_allreduce is None:
|
||||
@@ -1465,7 +1444,6 @@ def init_model_parallel_group(
|
||||
use_npu_communicator=True,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name,
|
||||
pynccl_use_current_stream=pynccl_use_current_stream,
|
||||
)
|
||||
|
||||
|
||||
@@ -1835,7 +1813,6 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="tp",
|
||||
pynccl_use_current_stream=duplicate_tp_group,
|
||||
)
|
||||
|
||||
if duplicate_tp_group:
|
||||
@@ -1851,7 +1828,6 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="pdmux_prefill_tp",
|
||||
pynccl_use_current_stream=True,
|
||||
)
|
||||
if _TP.pynccl_comm:
|
||||
_TP.pynccl_comm.disabled = False
|
||||
|
||||
Reference in New Issue
Block a user