diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index eb225392..e8441802 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -268,17 +268,23 @@ print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
-def module_size(module):
+def module_size(module, exclude_device=None):
     module_mem = 0
     sd = module.state_dict()
     for k in sd:
         t = sd[k]
+
+        if exclude_device is not None:
+            if t.device == exclude_device:
+                continue
+
         module_mem += t.nelement() * t.element_size()
     return module_mem
 
 class LoadedModel:
-    def __init__(self, model):
+    def __init__(self, model, memory_required):
         self.model = model
+        self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
 
@@ -286,10 +292,7 @@ class LoadedModel:
         return self.model.model_size()
 
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
-        else:
-            return self.model_memory()
+        return module_size(self.model.model, exclude_device=device)
 
     def model_load(self, async_kept_memory=-1):
         patch_model_to = None
@@ -324,7 +327,9 @@ class LoadedModel:
                         real_kept_memory += module_mem
                     else:
                         real_async_memory += module_mem
-                        m._apply(lambda x: x.pin_memory())
+                        m.to(self.model.offload_device)
+                        if is_device_cpu(self.model.offload_device):
+                            m._apply(lambda x: x.pin_memory())
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
@@ -339,7 +344,7 @@ class LoadedModel:
 
         return self.real_model
 
-    def model_unload(self):
+    def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
                 if hasattr(m, "prev_ldm_patched_cast_weights"):
@@ -348,11 +353,14 @@ class LoadedModel:
 
             self.model_accelerated = False
 
-        self.model.unpatch_model(self.model.offload_device)
-        self.model.model_patches_to(self.model.offload_device)
+        if avoid_model_moving:
+            self.model.unpatch_model()
+        else:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
-        return self.model is other.model
+        return self.model is other.model and self.memory_required == other.memory_required
 
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -363,9 +371,10 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
+    print(f"Reuse {len(to_unload)} loaded models")
+
     for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
 
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
@@ -400,7 +409,7 @@ def load_models_gpu(models, memory_required=0):
     models_to_load = []
    models_already_loaded = []
     for x in models:
-        loaded_model = LoadedModel(x)
+        loaded_model = LoadedModel(x, memory_required=memory_required)
 
         if loaded_model in current_loaded_models:
             index = current_loaded_models.index(loaded_model)
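
Context for reviewers: the new exclude_device argument changes model_memory_required(device) from an all-or-nothing answer (0 if the model is already on device, otherwise the full model size) to a per-tensor count of the bytes that still have to be moved, which is what partially offloaded models need. Below is a minimal sketch of that accounting, not part of the patch; it mirrors the patched module_size helper and uses an arbitrary nn.Linear purely for illustration.

import torch
import torch.nn as nn

def module_size(module, exclude_device=None):
    # Mirrors the patched helper: sum parameter/buffer bytes from the
    # state_dict, skipping tensors that already live on exclude_device.
    module_mem = 0
    sd = module.state_dict()
    for k in sd:
        t = sd[k]
        if exclude_device is not None and t.device == exclude_device:
            continue
        module_mem += t.nelement() * t.element_size()
    return module_mem

m = nn.Linear(1024, 1024)  # fp32 weight + bias on CPU, roughly 4 MB

print(module_size(m))                                      # full byte count
print(module_size(m, exclude_device=torch.device("cpu")))  # 0: nothing left to move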