diff --git a/backend/memory_management.py b/backend/memory_management.py
index eddf5224..88a0a006 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -441,9 +441,8 @@ def build_module_profile(model, model_gpu_memory_when_using_cpu_swap):
 
 
 class LoadedModel:
-    def __init__(self, model, memory_required):
+    def __init__(self, model):
         self.model = model
-        self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
         self.inclusive_memory = 0
@@ -607,16 +606,16 @@ def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_mem
     return int(max(0, suggestion))
 
 
-def load_models_gpu(models, memory_required=0):
+def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
     global vram_state
 
     execution_start_time = time.perf_counter()
-    extra_mem = max(minimum_inference_memory(), memory_required)
+    extra_mem = max(minimum_inference_memory(), memory_required + hard_memory_preservation)
 
     models_to_load = []
     models_already_loaded = []
     for x in models:
-        loaded_model = LoadedModel(x, memory_required=memory_required)
+        loaded_model = LoadedModel(x)
 
         if loaded_model in current_loaded_models:
             index = current_loaded_models.index(loaded_model)
@@ -664,7 +663,7 @@ def load_models_gpu(models, memory_required=0):
 
             previously_loaded = loaded_model.inclusive_memory
             current_free_mem = get_free_memory(torch_dev)
-            inference_memory = minimum_inference_memory()
+            inference_memory = minimum_inference_memory() + hard_memory_preservation
             estimated_remaining_memory = current_free_mem - model_require - inference_memory
 
             print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index 27cb1308..4f9ec039 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -382,7 +382,9 @@ def sampling_prepare(unet, x):
 
     memory_management.load_models_gpu(
         models=[unet] + additional_model_patchers,
-        memory_required=unet_inference_memory + additional_inference_memory)
+        memory_required=unet_inference_memory,
+        hard_memory_preservation=additional_inference_memory
+    )
 
     if unet.has_online_lora():
         utils.nested_move_to_device(unet.lora_patches, device=unet.current_device, dtype=unet.model.computation_dtype)
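
Note (not part of the patch): a minimal standalone sketch of how the two budget formulas touched above combine memory_required with the new hard_memory_preservation argument. The helper bodies and the numbers below are invented stand-ins, not the real backend.memory_management functions.

# Illustrative sketch only -- not the Forge implementation. The helpers are
# stand-ins so the two patched formulas can be exercised in isolation.

def minimum_inference_memory():
    return 1024 * 1024 * 1024          # assume a 1 GiB floor for this example


def get_free_memory():
    return 8 * 1024 * 1024 * 1024      # assume 8 GiB of free VRAM


def estimate_budget(model_require, memory_required=0, hard_memory_preservation=0):
    # Memory to free before loading, mirroring the patched extra_mem formula.
    extra_mem = max(minimum_inference_memory(), memory_required + hard_memory_preservation)
    # Per-model reservation used in the free-memory estimate; the hard
    # preservation is always added on top, so it can no longer be absorbed
    # by the max() against minimum_inference_memory().
    inference_memory = minimum_inference_memory() + hard_memory_preservation
    estimated_remaining_memory = get_free_memory() - model_require - inference_memory
    return extra_mem, estimated_remaining_memory


# Mirrors the new call style in sampling_prepare(): the UNet share goes through
# memory_required, the extra patchers' share through hard_memory_preservation.
print(estimate_budget(model_require=4 * 1024 ** 3,
                      memory_required=1024 ** 3,
                      hard_memory_preservation=512 * 1024 ** 2))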