mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-09 01:19:58 +00:00
add a way to empty cuda cache on the fly
This commit is contained in:
@@ -1091,8 +1091,11 @@ def can_install_bnb():
|
||||
return False
|
||||
|
||||
|
||||
signal_empty_cache = True
|
||||
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
global cpu_state, signal_empty_cache
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
@@ -1101,6 +1104,8 @@ def soft_empty_cache(force=False):
|
||||
if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
signal_empty_cache = False
|
||||
return
|
||||
|
||||
|
||||
def unload_all_models():
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from backend import utils
|
||||
from backend import utils, memory_management
|
||||
from bitsandbytes.nn.modules import Params4bit, QuantState
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
|
||||
@@ -50,6 +50,10 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta
|
||||
|
||||
|
||||
class ForgeParams4bit(Params4bit):
|
||||
def _quantize(self, device):
|
||||
memory_management.signal_empty_cache = True
|
||||
return super()._quantize(device)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None and device.type == "cuda" and not self.bnb_quantized:
|
||||
|
||||
@@ -187,6 +187,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
if memory_management.signal_empty_cache:
|
||||
memory_management.soft_empty_cache()
|
||||
|
||||
free_memory = memory_management.get_free_memory(x_in.device)
|
||||
|
||||
if (not args.disable_gpu_warning) and x_in.device.type == 'cuda':
|
||||
|
||||
Reference in New Issue
Block a user