From 14eac6f2cfc22d25f0abeb25d747272bb0dc4c07 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 22 Aug 2024 10:06:39 -0700
Subject: [PATCH] add a way to empty cuda cache on the fly

---
 backend/memory_management.py          | 7 ++++++-
 backend/operations_bnb.py             | 6 +++++-
 backend/sampling/sampling_function.py | 3 +++
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index dd64da6b..9f1321d7 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -1091,8 +1091,11 @@ def can_install_bnb():
     return False
 
 
+signal_empty_cache = True
+
+
 def soft_empty_cache(force=False):
-    global cpu_state
+    global cpu_state, signal_empty_cache
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif is_intel_xpu():
@@ -1101,6 +1104,8 @@ def soft_empty_cache(force=False):
         if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
+    signal_empty_cache = False
+    return
 
 
 def unload_all_models():
diff --git a/backend/operations_bnb.py b/backend/operations_bnb.py
index 654776ca..eca619aa 100644
--- a/backend/operations_bnb.py
+++ b/backend/operations_bnb.py
@@ -3,7 +3,7 @@
 import torch
 import bitsandbytes as bnb
 
-from backend import utils
+from backend import utils, memory_management
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from bitsandbytes.functional import dequantize_4bit
 
@@ -50,6 +50,10 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta
 
 
 class ForgeParams4bit(Params4bit):
+    def _quantize(self, device):
+        memory_management.signal_empty_cache = True
+        return super()._quantize(device)
+
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index 9ae54d36..1c32a09b 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -187,6 +187,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
         to_batch_temp.reverse()
         to_batch = to_batch_temp[:1]
 
+        if memory_management.signal_empty_cache:
+            memory_management.soft_empty_cache()
+
         free_memory = memory_management.get_free_memory(x_in.device)
 
         if (not args.disable_gpu_warning) and x_in.device.type == 'cuda':
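
Note: for readers outside the Forge codebase, the sketch below is a minimal, self-contained
illustration of the lazy cache-flush pattern this patch wires up. `signal_empty_cache` and
`soft_empty_cache` mirror the patched backend/memory_management.py; `on_quantize` and
`sampling_step` are hypothetical stand-ins for the producer (ForgeParams4bit._quantize in
operations_bnb.py) and consumer (calc_cond_uncond_batch in sampling_function.py) sides,
with platform branches (MPS/XPU/ROCm) omitted.

    import torch

    signal_empty_cache = True  # start True so the first consumer check flushes once


    def soft_empty_cache():
        """Return cached CUDA blocks to the driver, then clear the pending signal."""
        global signal_empty_cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        signal_empty_cache = False


    def on_quantize():
        # Producer: quantizing a 4-bit weight strands the unquantized copy in
        # the allocator cache, so flag a flush rather than paying for
        # torch.cuda.empty_cache() on every quantization call.
        global signal_empty_cache
        signal_empty_cache = True


    def sampling_step():
        # Consumer: flush at most once per batch, right before free memory is
        # measured, so the measurement is not skewed by cached-but-unused blocks.
        if signal_empty_cache:
            soft_empty_cache()

Deferring the flush to the consumer is the point of the flag: torch.cuda.empty_cache() is
relatively expensive, so the patch runs it at most once per sampling batch, just before
get_free_memory() is consulted, instead of after each weight quantization.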