diff --git a/backend/memory_management.py b/backend/memory_management.py
index e1c1ffba..15749987 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -322,24 +322,24 @@ class LoadedModel:
     def model_memory_required(self, device):
         return module_size(self.model.model, exclude_device=device)
 
-    def model_load(self, async_kept_memory=-1):
+    def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
         patch_model_to = None
-        disable_async_load = async_kept_memory < 0
+        do_not_need_cpu_swap = model_gpu_memory_when_using_cpu_swap < 0
 
-        if disable_async_load:
+        if do_not_need_cpu_swap:
             patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to)  # TODO: do something with loras and offloading to CPU
+            self.real_model = self.model.patch_model(device_to=patch_model_to)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e
 
-        if not disable_async_load:
+        if not do_not_need_cpu_swap:
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -347,7 +347,7 @@ class LoadedModel:
                     m.prev_parameters_manual_cast = m.parameters_manual_cast
                     m.parameters_manual_cast = True
                     module_mem = module_size(m)
-                    if mem_counter + module_mem < async_kept_memory:
+                    if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
                         m.to(self.device)
                         mem_counter += module_mem
                     else:
@@ -438,6 +438,18 @@ def free_memory(memory_required, device, keep_loaded=[]):
         soft_empty_cache()
 
 
+def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
+    maximum_memory_available = current_free_mem - inference_memory
+
+    k_1GB = float(inference_memory / (1024 * 1024 * 1024))
+    k_1GB = max(0.0, min(1.0, k_1GB))
+
+    adaptive_safe_factor = 1.0 - 0.23 * k_1GB
+    suggestion = maximum_memory_available * adaptive_safe_factor
+
+    return int(max(0, suggestion))
+
+
 def load_models_gpu(models, memory_required=0):
     global vram_state
 
@@ -489,7 +501,7 @@ def load_models_gpu(models, memory_required=0):
         else:
             vram_set_state = vram_state
 
-        async_kept_memory = -1
+        model_gpu_memory_when_using_cpu_swap = -1
 
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_memory = loaded_model.model_memory_required(torch_dev)
@@ -504,13 +516,12 @@ def load_models_gpu(models, memory_required=0):
 
             if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_kept_memory = (current_free_mem - inference_memory) / 1.3
-                async_kept_memory = int(max(0, async_kept_memory))
+                model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory)
 
         if vram_set_state == VRAMState.NO_VRAM:
-            async_kept_memory = 0
+            model_gpu_memory_when_using_cpu_swap = 0
 
-        loaded_model.model_load(async_kept_memory)
+        loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
         current_loaded_models.insert(0, loaded_model)
 
     moving_time = time.perf_counter() - execution_start_time
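
Note for reviewers: below is a standalone sketch (not part of the patch) of the sizing heuristic that compute_model_gpu_memory_when_using_cpu_swap introduces, compared against the fixed `/ 1.3` headroom it replaces. The byte values, the GiB constant, and the __main__ harness are made up purely for illustration. With inference_memory at or above 1 GiB the adaptive factor bottoms out at 1.0 - 0.23 = 0.77, which roughly matches the old 1 / 1.3 ≈ 0.77; with smaller inference budgets the new heuristic reserves less headroom and keeps more of the model on the GPU.

# Standalone sketch; hypothetical numbers, not taken from a real run.
GiB = 1024 * 1024 * 1024


def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
    maximum_memory_available = current_free_mem - inference_memory

    # Scale the safety margin with the inference budget, capped at 1 GiB.
    k_1GB = float(inference_memory / GiB)
    k_1GB = max(0.0, min(1.0, k_1GB))

    adaptive_safe_factor = 1.0 - 0.23 * k_1GB
    suggestion = maximum_memory_available * adaptive_safe_factor

    return int(max(0, suggestion))


if __name__ == '__main__':
    for free, inference in [(8 * GiB, 1 * GiB), (6 * GiB, 0.5 * GiB)]:
        new = compute_model_gpu_memory_when_using_cpu_swap(free, inference)
        old = int(max(0, (free - inference) / 1.3))  # previous fixed-factor formula
        print(f'free={free / GiB:.1f}GiB inference={inference / GiB:.1f}GiB '
              f'new={new / GiB:.2f}GiB old={old / GiB:.2f}GiB')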