Rework unload system

The repeated loading that previously happened (with ControlNet or other extensions) is fixed. ControlNet saves about 0.7 to 1.1 seconds on my two devices when batch count > 1.

8GB VRAM can now run SDXL at a resolution of 6144x6144, out of the box, without Tiled Diffusion or other tricks.

(The max resolution in the Automatic1111 txt2img UI is 2048, but one can use Hires. fix to try 6144 or even 8192.)

Author: lllyasviel
Date: 2024-02-23 12:58:09 -08:00
Committed by: GitHub
Parent: 19473b1a26
Commit: df12dde12e


@@ -268,17 +268,23 @@ print("VAE dtype:", VAE_DTYPE)
 current_loaded_models = []
-def module_size(module):
+def module_size(module, exclude_device=None):
     module_mem = 0
     sd = module.state_dict()
     for k in sd:
         t = sd[k]
+        if exclude_device is not None:
+            if t.device == exclude_device:
+                continue
         module_mem += t.nelement() * t.element_size()
     return module_mem
 class LoadedModel:
-    def __init__(self, model):
+    def __init__(self, model, memory_required):
         self.model = model
+        self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
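
For orientation, here is a minimal, self-contained sketch (plain PyTorch with a hypothetical toy module, not code from this repository) of what the new exclude_device parameter buys: tensors that already sit on the excluded device are skipped, so the helper counts only the bytes that would still have to be moved.

    import torch
    import torch.nn as nn

    def module_size(module, exclude_device=None):
        # Same accounting idea as the patched helper above: sum tensor sizes,
        # but skip tensors that already live on exclude_device.
        module_mem = 0
        sd = module.state_dict()
        for k in sd:
            t = sd[k]
            if exclude_device is not None and t.device == exclude_device:
                continue
            module_mem += t.nelement() * t.element_size()
        return module_mem

    toy = nn.Linear(1024, 1024)                                    # hypothetical toy module
    print(module_size(toy))                                        # full size in bytes
    print(module_size(toy, exclude_device=torch.device("cpu")))    # 0: weights already on CPU

Combined with the memory_required field on LoadedModel, this is what lets model_memory_required() in the next hunk report only the memory that is not yet on the target device.
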
@@ -286,10 +292,7 @@ class LoadedModel:
         return self.model.model_size()
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
-        else:
-            return self.model_memory()
+        return module_size(self.model.model, exclude_device=device)
     def model_load(self, async_kept_memory=-1):
         patch_model_to = None
@@ -324,7 +327,9 @@ class LoadedModel:
                     real_kept_memory += module_mem
                 else:
                     real_async_memory += module_mem
-                    m._apply(lambda x: x.pin_memory())
+                    m.to(self.model.offload_device)
+                    if is_device_cpu(self.model.offload_device):
+                        m._apply(lambda x: x.pin_memory())
             elif hasattr(m, "weight"):
                 m.to(self.device)
                 mem_counter += module_size(m)
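
The pinning call is now guarded: page-locking host memory only makes sense when the offload target really is the CPU. As a standalone illustration (ordinary PyTorch, independent of this repository) of why pinned memory matters for the asynchronous path:

    import torch

    # Page-locked (pinned) host memory is what lets a non_blocking host-to-GPU
    # copy overlap with other work; pinning a tensor destined for a non-CPU
    # offload device would be pointless, hence the is_device_cpu() check above.
    x = torch.randn(1024, 1024)                  # ordinary pageable CPU tensor

    if torch.cuda.is_available():
        x = x.pin_memory()                       # page-locked host copy
        y = x.to("cuda", non_blocking=True)      # asynchronous transfer
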
@@ -339,7 +344,7 @@ class LoadedModel:
         return self.real_model
-    def model_unload(self):
+    def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
                 if hasattr(m, "prev_ldm_patched_cast_weights"):
@@ -348,11 +353,14 @@ class LoadedModel:
             self.model_accelerated = False
-        self.model.unpatch_model(self.model.offload_device)
-        self.model.model_patches_to(self.model.offload_device)
+        if avoid_model_moving:
+            self.model.unpatch_model()
+        else:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model.model_patches_to(self.model.offload_device)
     def __eq__(self, other):
-        return self.model is other.model
+        return self.model is other.model and self.memory_required == other.memory_required
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
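
The avoid_model_moving branch calls unpatch_model() with no target device. Roughly, the difference is "restore the original weights in place" versus "restore them and also move everything to the offload device". A simplified, hypothetical stand-in (not the project's real ModelPatcher API) to show the control flow:

    import torch
    import torch.nn as nn

    def unpatch_model_sketch(model, backup, device_to=None):
        # Hypothetical helper: undo in-place patches by restoring backed-up
        # weights, and move the module only when a target device is given.
        model.load_state_dict(backup)
        if device_to is not None:
            model.to(device_to)
        return model

    m = nn.Linear(8, 8)
    backup = {k: v.clone() for k, v in m.state_dict().items()}
    m.weight.data += 1.0                                    # pretend a patch was applied
    unpatch_model_sketch(m, backup)                         # undo; weights stay where they are
    unpatch_model_sketch(m, backup, torch.device("cpu"))    # undo and offload to CPU

The clone-unload path in the next hunk uses the first form, so tearing down a clone no longer shuttles weights back to the offload device; this appears to be where the ControlNet timing improvement mentioned above comes from.
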
@@ -363,9 +371,10 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
+    print(f"Reuse {len(to_unload)} loaded models")
     for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
@@ -400,7 +409,7 @@ def load_models_gpu(models, memory_required=0):
     models_to_load = []
     models_already_loaded = []
     for x in models:
-        loaded_model = LoadedModel(x)
+        loaded_model = LoadedModel(x, memory_required=memory_required)
         if loaded_model in current_loaded_models:
             index = current_loaded_models.index(loaded_model)
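
Finally, since __eq__ now compares both the wrapped model and memory_required, the "loaded_model in current_loaded_models" check above treats a model as already loaded only when it was loaded under the same memory budget, and skips reloading it in that case. A toy sketch of that identity logic (simplified class, illustration only):

    class LoadedModelSketch:
        # Minimal stand-in for the LoadedModel identity logic in this commit.
        def __init__(self, model, memory_required):
            self.model = model
            self.memory_required = memory_required

        def __eq__(self, other):
            return self.model is other.model and self.memory_required == other.memory_required

    model = object()                                      # placeholder for a real patcher
    cache = [LoadedModelSketch(model, memory_required=0)]

    print(LoadedModelSketch(model, 0) in cache)           # True: reused, no reload
    print(LoadedModelSketch(model, 2 ** 30) in cache)     # False: different budget, load again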