diff --git a/backend/memory_management.py b/backend/memory_management.py
index c3d80edb..66143633 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -350,7 +350,7 @@ def bake_gguf_model(model):
     return model
 
 
-def module_size(module, exclude_device=None, return_split=False):
+def module_size(module, exclude_device=None, include_device=None, return_split=False):
     module_mem = 0
     weight_mem = 0
     weight_patterns = ['weight']
@@ -362,6 +362,10 @@ def module_size(module, exclude_device=None, return_split=False):
             if t.device == exclude_device:
                 continue
 
+        if include_device is not None:
+            if t.device != include_device:
+                continue
+
         element_size = t.element_size()
 
         if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
@@ -439,12 +443,13 @@ class LoadedModel:
         self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
+        self.inclusive_memory = 0
+        self.exclusive_memory = 0
 
-    def model_memory(self):
-        return self.model.model_size()
-
-    def model_memory_required(self, device=None):
-        return module_size(self.model.model, exclude_device=device)
+    def compute_inclusive_exclusive_memory(self):
+        self.inclusive_memory = module_size(self.model.model, include_device=self.device)
+        self.exclusive_memory = module_size(self.model.model, exclude_device=self.device)
+        return
 
     def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
         patch_model_to = None
@@ -548,9 +553,6 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
-    if len(to_unload) > 0:
-        print(f"Reuse {len(to_unload)} loaded models")
-
     for i in to_unload:
         current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
 
@@ -632,10 +634,13 @@ def load_models_gpu(models, memory_required=0):
 
         return
 
-    total_memory_required = {}
     for loaded_model in models_to_load:
         unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required()
+
+    total_memory_required = {}
+    for loaded_model in models_to_load:
+        loaded_model.compute_inclusive_exclusive_memory()
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -652,16 +657,20 @@ def load_models_gpu(models, memory_required=0):
         model_gpu_memory_when_using_cpu_swap = -1
 
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_memory = loaded_model.model_memory_required(torch_dev)
+            model_require = loaded_model.exclusive_memory
+            previously_loaded = loaded_model.inclusive_memory
+
             current_free_mem = get_free_memory(torch_dev)
             inference_memory = minimum_inference_memory()
-            estimated_remaining_memory = current_free_mem - model_memory - inference_memory
+            estimated_remaining_memory = current_free_mem - model_require - inference_memory
 
-            print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
+            print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
 
             if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
                 model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory)
+                if previously_loaded > 0:
+                    model_gpu_memory_when_using_cpu_swap = previously_loaded
 
         if vram_set_state == VRAMState.NO_VRAM:
             model_gpu_memory_when_using_cpu_swap = 0
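
A minimal sketch of how the new per-device budget in load_models_gpu is intended to combine the two numbers reported by compute_inclusive_exclusive_memory. The helper name estimate_budget and the byte counts below are illustrative only (not part of the patch); the real code derives exclusive_memory (parameters not yet on the load device) and inclusive_memory (parameters already resident there) from the actual tensors via module_size.

    # Hypothetical illustration of the budgeting rule used above:
    # count the full cost of what still has to be moved, plus a 25%
    # allowance for weights that are already on the target device.
    def estimate_budget(exclusive_memory: float, inclusive_memory: float) -> float:
        return exclusive_memory + inclusive_memory * 0.25

    # Example with made-up sizes: 2 GiB still on CPU, 4 GiB already on the GPU.
    exclusive = 2 * 1024 ** 3
    inclusive = 4 * 1024 ** 3
    print(f"budget: {estimate_budget(exclusive, inclusive) / (1024 * 1024):.2f} MB")  # 3072.00 MB

Under this rule, weights that are already on the GPU no longer inflate the requirement at full size, which is also why the fallback path reuses previously_loaded as the swap budget when memory runs short.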