Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-04-27 01:49:07 +00:00
In-progress commit on making flipflop async weight streaming native; also gave the "loaded partially"/"loaded completely" log messages labeled fields, because having to memorize what each bare number means is annoying for dev work.
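For context on what the commit message means by flipflop async weight streaming: the usual scheme keeps most of a model's blocks in system RAM and double-buffers them through two GPU-resident holders, copying the next block's weights on a side CUDA stream while the current block computes, then swapping ("flipping") which holder is in use. The sketch below only illustrates that idea under stated assumptions and is not code from this commit; FlipFlopRunner and every name in it are hypothetical, and it assumes an NVIDIA GPU, blocks that share one architecture, and CPU-side weights kept in pinned memory so the non_blocking copies can actually overlap with compute.

# Minimal sketch of the "flipflop" (double-buffered) weight-streaming idea.
# NOT ComfyUI's implementation; all names here are hypothetical.
import copy
import torch

class FlipFlopRunner:
    def __init__(self, blocks):
        self.blocks = blocks                                  # CPU-resident blocks (pinned weights assumed)
        self.device = torch.device("cuda")
        self.copy_stream = torch.cuda.Stream()                # side stream used only for H2D copies
        # Two GPU "holders"; weights are streamed into them alternately (flip/flop).
        self.holders = [copy.deepcopy(blocks[0]).to(self.device) for _ in range(2)]
        # Events marking when each holder's previous compute has finished.
        self.holder_free = [torch.cuda.Event() for _ in range(2)]
        for ev in self.holder_free:
            ev.record()                                       # both holders start out free

    def _prefetch(self, index):
        holder = self.holders[index % 2]
        src = self.blocks[index]
        with torch.cuda.stream(self.copy_stream):
            # Do not overwrite a holder that is still being read by a running kernel.
            self.copy_stream.wait_event(self.holder_free[index % 2])
            with torch.no_grad():
                for dst, s in zip(holder.parameters(), src.parameters()):
                    dst.copy_(s, non_blocking=True)           # async host-to-device copy
            ready = torch.cuda.Event()
            ready.record(self.copy_stream)                    # signals "this block's weights have arrived"
        return holder, ready

    @torch.no_grad()
    def forward(self, x):
        nxt = self._prefetch(0)
        for i in range(len(self.blocks)):
            holder, ready = nxt
            if i + 1 < len(self.blocks):
                nxt = self._prefetch(i + 1)                   # overlap the next copy with this compute
            torch.cuda.current_stream().wait_event(ready)     # compute waits for this block's weights
            x = holder(x)
            self.holder_free[i % 2].record()                  # holder is reusable once this compute is done
        return x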
@@ -605,7 +605,27 @@ class ModelPatcher:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
-        return hasattr(self.model.diffusion_model, "flipflop")
+        # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
+        if not hasattr(self.model, "diffusion_model"):
+            return False
+        if not hasattr(self.model.diffusion_model, "flipflop"):
+            return False
+        if not comfy.model_management.is_nvidia():
+            return False
+        if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
+            return False
+        return True
+
+    def init_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        # figure out how many b
+        self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
+
+    def clean_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        self.model.diffusion_model.clean_flipflop_holders()
 
     def _load_list(self):
         loading = []
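Taken together, these methods suggest an intended call order for a ModelPatcher that supports streaming: gate on supports_flipflop(), set up the holders before loading, and tear them down on unload. The snippet below is only an illustrative call sequence, not code from this commit; `patcher` stands for a ModelPatcher instance.

# Illustrative call order only; `patcher` is assumed to be a ModelPatcher instance.
if patcher.supports_flipflop():
    patcher.init_flipflop()    # builds the flip/flop holders from model_options["flipflop_block_percentage"]
# ... load weights and run sampling with part of the blocks streamed through VRAM ...
patcher.clean_flipflop()       # drops the holders when the model is unloaded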
@@ -628,6 +648,9 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
+        if self.supports_flipflop():
+            ...
         loading = self._load_list()
 
         load_completely = []
@@ -647,6 +670,7 @@ class ModelPatcher:
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
+                    lowvram_mem_counter += module_mem
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue
 
@@ -709,10 +733,10 @@ class ModelPatcher:
             x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
             self.model.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
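With purely illustrative numbers plugged into the new f-strings, the relabeled messages would read like:

loaded partially; 4096.00 MB usable memory, 3584.50 MB loaded, 2210.25 MB offloaded, lowvram patches: 12
loaded completely; 8192.00 MB usable memory, 6144.00 MB loaded, full load: True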