Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
add a way to empty cuda cache on the fly
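
In outline: the diff below adds a module-level flag, `signal_empty_cache`, to `memory_management`. `ForgeParams4bit._quantize` raises the flag whenever bitsandbytes quantizes a parameter, and `calc_cond_uncond_batch` polls it before sizing the next batch, calling `soft_empty_cache()`, which flushes the CUDA allocator cache and lowers the flag again.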
@@ -1091,8 +1091,11 @@ def can_install_bnb():
         return False
 
 
+signal_empty_cache = True
+
+
 def soft_empty_cache(force=False):
-    global cpu_state
+    global cpu_state, signal_empty_cache
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif is_intel_xpu():
@@ -1101,6 +1104,8 @@ def soft_empty_cache(force=False):
         if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
+    signal_empty_cache = False
+    return
 
 
 def unload_all_models():
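
Note that `soft_empty_cache()` now also lowers the flag once it has flushed. Because the flag is cleared here rather than at the polling site, any number of producers can request a flush and it is still performed only once.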
@@ -3,7 +3,7 @@
 import torch
 import bitsandbytes as bnb
 
-from backend import utils
+from backend import utils, memory_management
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from bitsandbytes.functional import dequantize_4bit
 
@@ -50,6 +50,10 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
 
 
 class ForgeParams4bit(Params4bit):
+    def _quantize(self, device):
+        memory_management.signal_empty_cache = True
+        return super()._quantize(device)
+
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
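
The override only raises the flag and defers the actual flush: quantization happens lazily on the first move to CUDA (see the `to()` path above), and calling `torch.cuda.empty_cache()` once per quantized parameter would presumably be too expensive, since flushing the allocator cache is a relatively costly operation. The flush is instead picked up at a safe point in the sampling loop below; a standalone sketch of the whole pattern follows the last hunk.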
@@ -187,6 +187,9 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
         to_batch_temp.reverse()
         to_batch = to_batch_temp[:1]
 
+        if memory_management.signal_empty_cache:
+            memory_management.soft_empty_cache()
+
         free_memory = memory_management.get_free_memory(x_in.device)
 
         if (not args.disable_gpu_warning) and x_in.device.type == 'cuda':
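
For reference, a minimal, self-contained sketch of the signalling pattern these hunks implement. This is not the Forge code: `quantize_weights` and `sampling_step` are hypothetical stand-ins for `ForgeParams4bit._quantize` and `calc_cond_uncond_batch`, the Forge-specific helpers (`cpu_state`, `is_nvidia`, ...) are omitted, and only `torch` is assumed.

import torch

signal_empty_cache = True  # module-level flag: producers set it, soft_empty_cache() clears it

def soft_empty_cache(force=False):
    global signal_empty_cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached allocator blocks back to the driver
        torch.cuda.ipc_collect()  # release memory held by dead CUDA IPC handles
    signal_empty_cache = False    # the pending flush request is now handled

def quantize_weights(param, device):
    # hypothetical stand-in for ForgeParams4bit._quantize: the real call
    # churns through VRAM, so it flags the allocator cache for a later flush
    global signal_empty_cache
    signal_empty_cache = True
    return param.to(device)  # placeholder for the real bitsandbytes quantization

def sampling_step(x):
    # hypothetical stand-in for calc_cond_uncond_batch: the flush happens
    # lazily here, before free memory is measured for the next batch
    if signal_empty_cache:
        soft_empty_cache()
    return x  # the real function would run the batched denoising model

The point of the indirection is that the expensive flush runs once per request, at a moment the consumer chooses, no matter how many parameters raised the flag.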