[diffusion]: Improve layerwise offload buffer reuse and shared-storage handling (#18611)

This commit is contained in:
Ratish P
2026-02-15 19:47:51 +05:30
committed by GitHub
parent 4e162d4b1b
commit ddfe147377

View File

@@ -65,9 +65,31 @@ class LayerwiseOffloadManager:
self._named_buffers: Dict[str, torch.Tensor] = {}
# Store forward hooks for removal
self._forward_hooks: List[Any] = []
# GPU buffer pool: dtype -> numel -> [buffers]
self._gpu_buffer_pool: Dict[torch.dtype, Dict[int, List[torch.Tensor]]] = {}
# layer_idx -> {dtype: gpu_buffer}
self._layer_gpu_buffers: Dict[int, Dict[torch.dtype, torch.Tensor]] = {}
self._gpu_buffer_pool_max = max(2, 2 * self.prefetch_size)
# Keep layer 0 resident during the forward pass to avoid redundant reloads
self._resident_window = 1
self._initialize()
def _acquire_gpu_buffer(self, dtype: torch.dtype, numel: int) -> torch.Tensor:
pool_by_dtype = self._gpu_buffer_pool.setdefault(dtype, {})
bucket = pool_by_dtype.get(numel)
if bucket:
return bucket.pop()
return torch.empty((numel,), dtype=dtype, device=self.device)
def _release_gpu_buffer(
self, dtype: torch.dtype, numel: int, buffer: torch.Tensor
) -> None:
pool_by_dtype = self._gpu_buffer_pool.setdefault(dtype, {})
bucket = pool_by_dtype.setdefault(numel, [])
if len(bucket) < self._gpu_buffer_pool_max:
bucket.append(buffer)
def _match_layer_idx(self, name: str) -> int | None:
m = self._layer_name_re.search(name)
if not m:
@@ -82,12 +104,14 @@ class LayerwiseOffloadManager:
if not self.enabled:
return
self._named_parameters = dict(self.model.named_parameters())
self._named_buffers = dict(self.model.named_buffers())
named_parameters = list(self.model.named_parameters())
named_buffers = list(self.model.named_buffers())
self._named_parameters = dict(named_parameters)
self._named_buffers = dict(named_buffers)
# 1. collect and group tensors by layer and dtype
layer_groups: Dict[int, Dict[torch.dtype, List[Tuple[str, torch.Tensor]]]] = {}
all_tensors = chain(self._named_parameters.items(), self._named_buffers.items())
all_tensors = list(chain(named_parameters, named_buffers))
for name, tensor in all_tensors:
layer_idx = self._match_layer_idx(name)
if layer_idx is None or layer_idx >= self.num_layers:
@@ -173,9 +197,7 @@ class LayerwiseOffloadManager:
gpu_buffers: Dict[torch.dtype, torch.Tensor] = {}
with torch.cuda.stream(self.copy_stream):
for dtype, cpu_buffer in self._consolidated_cpu_weights[layer_idx].items():
gpu_buffer = torch.empty(
cpu_buffer.shape, dtype=dtype, device=self.device
)
gpu_buffer = self._acquire_gpu_buffer(dtype, cpu_buffer.numel())
gpu_buffer.copy_(cpu_buffer, non_blocking=non_blocking)
gpu_buffers[dtype] = gpu_buffer
@@ -196,9 +218,10 @@ class LayerwiseOffloadManager:
].view(meta["shape"])
self._gpu_layers.add(layer_idx)
self._layer_gpu_buffers[layer_idx] = gpu_buffers
@torch.compiler.disable
def release_layer(self, layer_idx: int) -> None:
def release_layer(self, layer_idx: int, force: bool = False) -> None:
"""
lightweight release layer weights
Basically set the reference count to the gpu weight tensor to zero. The weights on cpu is untouched
@@ -209,7 +232,7 @@ class LayerwiseOffloadManager:
# clear prefetch event, since it's useless and needs to be reset
self._prefetch_events.pop(layer_idx, None)
if layer_idx <= 0:
if layer_idx < self._resident_window and not force:
return
if layer_idx not in self._gpu_layers:
@@ -219,6 +242,11 @@ class LayerwiseOffloadManager:
target = self.get_target_with_name(name)
target.data = torch.empty((1,), device=self.device, dtype=meta["dtype"])
layer_buffers = self._layer_gpu_buffers.pop(layer_idx, None)
if layer_buffers is not None:
for dtype, buffer in layer_buffers.items():
self._release_gpu_buffer(dtype, buffer.numel(), buffer)
self._gpu_layers.discard(layer_idx)
@torch.compiler.disable
@@ -229,7 +257,9 @@ class LayerwiseOffloadManager:
torch.cuda.current_stream().wait_stream(self.copy_stream)
for layer_idx in list(self._gpu_layers):
self.release_layer(layer_idx)
self.release_layer(layer_idx, force=True)
self._gpu_buffer_pool.clear()
@torch.compiler.disable
def load_all_layers(self) -> None: