mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-03 22:07:12 +00:00
[diffusion]: Improve layerwise offload buffer reuse and shared-storage handling (#18611)
This commit is contained in:
@@ -65,9 +65,31 @@ class LayerwiseOffloadManager:
|
||||
self._named_buffers: Dict[str, torch.Tensor] = {}
|
||||
# Store forward hooks for removal
|
||||
self._forward_hooks: List[Any] = []
|
||||
# GPU buffer pool: dtype -> numel -> [buffers]
|
||||
self._gpu_buffer_pool: Dict[torch.dtype, Dict[int, List[torch.Tensor]]] = {}
|
||||
# layer_idx -> {dtype: gpu_buffer}
|
||||
self._layer_gpu_buffers: Dict[int, Dict[torch.dtype, torch.Tensor]] = {}
|
||||
self._gpu_buffer_pool_max = max(2, 2 * self.prefetch_size)
|
||||
# Keep layer 0 resident during the forward pass to avoid redundant reloads
|
||||
self._resident_window = 1
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _acquire_gpu_buffer(self, dtype: torch.dtype, numel: int) -> torch.Tensor:
|
||||
pool_by_dtype = self._gpu_buffer_pool.setdefault(dtype, {})
|
||||
bucket = pool_by_dtype.get(numel)
|
||||
if bucket:
|
||||
return bucket.pop()
|
||||
return torch.empty((numel,), dtype=dtype, device=self.device)
|
||||
|
||||
def _release_gpu_buffer(
|
||||
self, dtype: torch.dtype, numel: int, buffer: torch.Tensor
|
||||
) -> None:
|
||||
pool_by_dtype = self._gpu_buffer_pool.setdefault(dtype, {})
|
||||
bucket = pool_by_dtype.setdefault(numel, [])
|
||||
if len(bucket) < self._gpu_buffer_pool_max:
|
||||
bucket.append(buffer)
|
||||
|
||||
def _match_layer_idx(self, name: str) -> int | None:
|
||||
m = self._layer_name_re.search(name)
|
||||
if not m:
|
||||
@@ -82,12 +104,14 @@ class LayerwiseOffloadManager:
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self._named_parameters = dict(self.model.named_parameters())
|
||||
self._named_buffers = dict(self.model.named_buffers())
|
||||
named_parameters = list(self.model.named_parameters())
|
||||
named_buffers = list(self.model.named_buffers())
|
||||
self._named_parameters = dict(named_parameters)
|
||||
self._named_buffers = dict(named_buffers)
|
||||
|
||||
# 1. collect and group tensors by layer and dtype
|
||||
layer_groups: Dict[int, Dict[torch.dtype, List[Tuple[str, torch.Tensor]]]] = {}
|
||||
all_tensors = chain(self._named_parameters.items(), self._named_buffers.items())
|
||||
all_tensors = list(chain(named_parameters, named_buffers))
|
||||
for name, tensor in all_tensors:
|
||||
layer_idx = self._match_layer_idx(name)
|
||||
if layer_idx is None or layer_idx >= self.num_layers:
|
||||
@@ -173,9 +197,7 @@ class LayerwiseOffloadManager:
|
||||
gpu_buffers: Dict[torch.dtype, torch.Tensor] = {}
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
for dtype, cpu_buffer in self._consolidated_cpu_weights[layer_idx].items():
|
||||
gpu_buffer = torch.empty(
|
||||
cpu_buffer.shape, dtype=dtype, device=self.device
|
||||
)
|
||||
gpu_buffer = self._acquire_gpu_buffer(dtype, cpu_buffer.numel())
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=non_blocking)
|
||||
gpu_buffers[dtype] = gpu_buffer
|
||||
|
||||
@@ -196,9 +218,10 @@ class LayerwiseOffloadManager:
|
||||
].view(meta["shape"])
|
||||
|
||||
self._gpu_layers.add(layer_idx)
|
||||
self._layer_gpu_buffers[layer_idx] = gpu_buffers
|
||||
|
||||
@torch.compiler.disable
|
||||
def release_layer(self, layer_idx: int) -> None:
|
||||
def release_layer(self, layer_idx: int, force: bool = False) -> None:
|
||||
"""
|
||||
lightweight release layer weights
|
||||
Basically set the reference count to the gpu weight tensor to zero. The weights on cpu is untouched
|
||||
@@ -209,7 +232,7 @@ class LayerwiseOffloadManager:
|
||||
# clear prefetch event, since it's useless and needs to be reset
|
||||
self._prefetch_events.pop(layer_idx, None)
|
||||
|
||||
if layer_idx <= 0:
|
||||
if layer_idx < self._resident_window and not force:
|
||||
return
|
||||
|
||||
if layer_idx not in self._gpu_layers:
|
||||
@@ -219,6 +242,11 @@ class LayerwiseOffloadManager:
|
||||
target = self.get_target_with_name(name)
|
||||
target.data = torch.empty((1,), device=self.device, dtype=meta["dtype"])
|
||||
|
||||
layer_buffers = self._layer_gpu_buffers.pop(layer_idx, None)
|
||||
if layer_buffers is not None:
|
||||
for dtype, buffer in layer_buffers.items():
|
||||
self._release_gpu_buffer(dtype, buffer.numel(), buffer)
|
||||
|
||||
self._gpu_layers.discard(layer_idx)
|
||||
|
||||
@torch.compiler.disable
|
||||
@@ -229,7 +257,9 @@ class LayerwiseOffloadManager:
|
||||
torch.cuda.current_stream().wait_stream(self.copy_stream)
|
||||
|
||||
for layer_idx in list(self._gpu_layers):
|
||||
self.release_layer(layer_idx)
|
||||
self.release_layer(layer_idx, force=True)
|
||||
|
||||
self._gpu_buffer_pool.clear()
|
||||
|
||||
@torch.compiler.disable
|
||||
def load_all_layers(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user