add a way to empty cuda cache on the fly

layerdiffusion
2024-08-22 10:06:39 -07:00
parent 64b5ce49d1
commit 14eac6f2cf
3 changed files with 14 additions and 2 deletions
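Taken together, the three diffs below introduce a module-level flag, signal_empty_cache: the bitsandbytes 4-bit quantization path raises it when weights are quantized onto the GPU, and the sampling loop consumes it by calling soft_empty_cache(), which flushes the CUDA caching allocator and lowers the flag again. A standalone sketch of that signal-then-flush pattern (hypothetical names, not the actual Forge modules) might look like this:

# Standalone sketch of the signal-then-flush pattern (hypothetical names,
# not the actual Forge modules).
import torch

signal_empty_cache = False  # raised by producers of large CUDA allocations


def flag_cache_dirty() -> None:
    # Producer side: e.g. called right after quantizing weights onto the GPU.
    global signal_empty_cache
    signal_empty_cache = True


def flush_if_signalled() -> None:
    # Consumer side: flush at most once per signal, then lower the flag.
    global signal_empty_cache
    if signal_empty_cache and torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    signal_empty_cache = False


def sampling_step(x: torch.Tensor) -> torch.Tensor:
    flush_if_signalled()  # cheap no-op unless something raised the flag
    return x * 2          # stand-in for the real denoising call


if __name__ == "__main__":
    flag_cache_dirty()                # weights were just moved/quantized
    _ = sampling_step(torch.ones(4))  # first step flushes, later steps skip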


@@ -1091,8 +1091,11 @@ def can_install_bnb():
         return False
 
 
+signal_empty_cache = True
+
+
 def soft_empty_cache(force=False):
-    global cpu_state
+    global cpu_state, signal_empty_cache
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif is_intel_xpu():
@@ -1101,6 +1104,8 @@ def soft_empty_cache(force=False):
         if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
+    signal_empty_cache = False
+    return
 
 
 def unload_all_models():
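As a usage note (a sketch, assuming the module is imported as from backend import memory_management, as the second diff does): force=True flushes even where is_nvidia() is false, and the call always lowers the flag on its way out, so repeated calls stay cheap.

# Sketch: the flush itself lowers the flag again (assumes the Forge import path).
from backend import memory_management

memory_management.signal_empty_cache = True      # pretend a producer just raised it
memory_management.soft_empty_cache(force=True)   # flush (when CUDA is available) even where is_nvidia() is False
assert memory_management.signal_empty_cache is False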


@@ -3,7 +3,7 @@
 import torch
 import bitsandbytes as bnb
 
-from backend import utils
+from backend import utils, memory_management
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from bitsandbytes.functional import dequantize_4bit
 
@@ -50,6 +50,10 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
 
 
 class ForgeParams4bit(Params4bit):
+    def _quantize(self, device):
+        memory_management.signal_empty_cache = True
+        return super()._quantize(device)
+
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
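The _quantize override above is the only producer in these diffs, but the convention is generic: any code path that suddenly moves large tensors onto the GPU could raise the same flag. A hypothetical helper, for illustration only and not part of this commit:

# Hypothetical producer, for illustration only: raise the flag after a large
# transfer so the next sampling batch flushes the allocator cache first.
import torch
from backend import memory_management


def move_weights_to_gpu(weights: torch.Tensor) -> torch.Tensor:
    out = weights.to("cuda", non_blocking=True)
    memory_management.signal_empty_cache = True  # same convention as ForgeParams4bit._quantize
    return out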


@@ -187,6 +187,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
         to_batch_temp.reverse()
         to_batch = to_batch_temp[:1]
 
+        if memory_management.signal_empty_cache:
+            memory_management.soft_empty_cache()
+
         free_memory = memory_management.get_free_memory(x_in.device)
 
         if (not args.disable_gpu_warning) and x_in.device.type == 'cuda':
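Because soft_empty_cache() lowers the flag, this per-batch check is effectively one-shot: the flush happens on the first batch after quantization and the branch is skipped afterwards, before free memory is measured for the batching decision. A sketch of that behaviour in a loop (the real loop lives in calc_cond_uncond_batch):

# Sketch: the check is one-shot per signal, so it is safe inside the hot loop.
import torch
from backend import memory_management

device = torch.device("cuda")
for step in range(4):
    if memory_management.signal_empty_cache:       # True only right after quantization
        memory_management.soft_empty_cache()       # flushes once, then lowers the flag
    free_memory = memory_management.get_free_memory(device)
    # ... pick a batch size from free_memory and run the cond/uncond forward pass ...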