diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2167f81bf..cd035f017 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,8 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+import threading
 import torch
 import sys
 import platform
@@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     soft_empty_cache()
     return unloaded_models
 
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
 
@@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         current_loaded_models.insert(0, loaded_model)
     return
 
-def load_model_gpu(model):
-    return load_models_gpu([model])
+def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
+    with torch.inference_mode():
+        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+    soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    # Deliberately load models outside of the Aimdo mempool so they can be retained across
+    # nodes. Use a dummy thread to do it: PyTorch documents that mempool contexts are
+    # thread-local, so we exploit that to escape the active context.
+    if enables_dynamic_vram():
+        t = threading.Thread(
+            target=load_models_gpu_thread,
+            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        )
+        t.start()
+        t.join()
+    else:
+        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
+                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
 
 def loaded_models(only_currently_used=False):
     output = []
@@ -1717,9 +1735,6 @@ def debug_memory_summary():
         return torch.cuda.memory.memory_summary()
     return ""
 
-#TODO: might be cleaner to put this somewhere else
-import threading
-
 class InterruptProcessingException(Exception):
     pass
 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b70c031bf..cdf289395 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
 
         if unpatch_weights:
             self.partially_unload_ram(1e32)
-            self.partially_unload(None)
+            self.partially_unload(None, 1e32)
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
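
A minimal standalone sketch of the thread-local mempool escape that the new `load_models_gpu()` relies on, assuming a CUDA-enabled PyTorch build recent enough to expose `torch.cuda.MemPool` / `torch.cuda.use_mem_pool`; the helper names (`escape_pool_alloc`, `result`) are illustrative only and not part of the patch:

```python
import threading

import torch


def escape_pool_alloc(result):
    # Runs in a worker thread: use_mem_pool() is documented as thread-local,
    # so this allocation is NOT routed into the pool active in the main thread.
    result["outside"] = torch.empty(1 << 20, device="cuda")


pool = torch.cuda.MemPool()
result = {}

with torch.cuda.use_mem_pool(pool):
    inside = torch.empty(1 << 20, device="cuda")  # allocated from `pool`

    worker = threading.Thread(target=escape_pool_alloc, args=(result,))
    worker.start()
    worker.join()

# `inside` is tied to the pool's lifetime, while result["outside"] lives in the
# default allocator and can outlive the pool. Handing the real model load off to
# a short-lived dummy thread (as load_models_gpu() does) uses the same property
# to keep weights resident across nodes.
```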