From df12dde12efbc922054521dd0b39ca63dfdd8a7e Mon Sep 17 00:00:00 2001
From: lllyasviel <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 23 Feb 2024 12:58:09 -0800
Subject: [PATCH] Rework unload system

Previous repeated loading (on cn or other extensions) is fixed.

ControlNet saves about 0.7 to 1.1 seconds on my two devices when batch count > 1.

8GB VRAM can use SDXL at resolution 6144x6144 now, out of the box, without
tiled diffusion or other workarounds. (The max resolution in the Automatic1111
txt2img UI is 2048, but one can use highres fix to try 6144 or even 8192.)
---
 ldm_patched/modules/model_management.py | 37 +++++++++++++++----------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index eb225392..e8441802 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -268,17 +268,23 @@ print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
-def module_size(module):
+def module_size(module, exclude_device=None):
     module_mem = 0
     sd = module.state_dict()
     for k in sd:
         t = sd[k]
+
+        if exclude_device is not None:
+            if t.device == exclude_device:
+                continue
+
         module_mem += t.nelement() * t.element_size()
     return module_mem
 
 class LoadedModel:
-    def __init__(self, model):
+    def __init__(self, model, memory_required):
         self.model = model
+        self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
 
@@ -286,10 +292,7 @@ class LoadedModel:
         return self.model.model_size()
 
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
-        else:
-            return self.model_memory()
+        return module_size(self.model.model, exclude_device=device)
 
     def model_load(self, async_kept_memory=-1):
         patch_model_to = None
@@ -324,7 +327,9 @@ class LoadedModel:
                         real_kept_memory += module_mem
                     else:
                         real_async_memory += module_mem
-                        m._apply(lambda x: x.pin_memory())
+                        m.to(self.model.offload_device)
+                        if is_device_cpu(self.model.offload_device):
+                            m._apply(lambda x: x.pin_memory())
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
@@ -339,7 +344,7 @@ class LoadedModel:
 
         return self.real_model
 
-    def model_unload(self):
+    def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
                 if hasattr(m, "prev_ldm_patched_cast_weights"):
@@ -348,11 +353,14 @@ class LoadedModel:
 
         self.model_accelerated = False
 
-        self.model.unpatch_model(self.model.offload_device)
-        self.model.model_patches_to(self.model.offload_device)
+        if avoid_model_moving:
+            self.model.unpatch_model()
+        else:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
-        return self.model is other.model
+        return self.model is other.model and self.memory_required == other.memory_required
 
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -363,9 +371,10 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
+    print(f"Reuse {len(to_unload)} loaded models")
+
     for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
 
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
@@ -400,7 +409,7 @@ def load_models_gpu(models, memory_required=0):
     models_to_load = []
     models_already_loaded = []
     for x in models:
-        loaded_model = LoadedModel(x)
+        loaded_model = LoadedModel(x, memory_required=memory_required)
 
         if loaded_model in current_loaded_models:
             index = current_loaded_models.index(loaded_model)
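
For illustration only, not part of the commit: a minimal standalone sketch of the
reworked module_size helper, showing why model_memory_required(device) now reports
only the bytes that still have to be moved. The nn.Linear module and the example
calls below are hypothetical, not code from the repository.

    import torch
    import torch.nn as nn

    def module_size(module, exclude_device=None):
        # Sum the byte size of every tensor in the module's state dict,
        # skipping tensors that already live on exclude_device.
        module_mem = 0
        sd = module.state_dict()
        for k in sd:
            t = sd[k]
            if exclude_device is not None and t.device == exclude_device:
                continue
            module_mem += t.nelement() * t.element_size()
        return module_mem

    m = nn.Linear(1024, 1024)  # hypothetical module; weights start on the CPU
    print(module_size(m))                                      # full size in bytes
    print(module_size(m, exclude_device=torch.device("cpu")))  # 0: nothing left to move

With exclude_device set to the load device, a model whose weights already sit on that
device contributes 0 to the memory budget, which is what lets repeated loads from
ControlNet and other extensions skip the old full reload.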