mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 14:19:57 +00:00
Compare commits
9 Commits
luke-mino-
...
mark-dtype
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d62334eb9e | ||
|
|
ac4a943ff3 | ||
|
|
8811db52db | ||
|
|
0a7446ade4 | ||
|
|
9b85cf9558 | ||
|
|
d531e3fb2a | ||
|
|
eb011733b6 | ||
|
|
ac6513e142 | ||
|
|
b6ddc590ed |
@@ -796,6 +796,8 @@ def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
@@ -828,11 +830,14 @@ def unet_offload_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
cpu_dev = torch.device("cpu")
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return cpu_dev
|
||||
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
|
||||
return cpu_dev
|
||||
|
||||
@@ -840,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@@ -943,6 +948,9 @@ def text_encoder_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return offload_device
|
||||
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ class ModelPatcher:
|
||||
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.backup_buffers = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
@@ -306,10 +307,16 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
|
||||
#the vast majority of setups a little bit of offloading on the giant model more
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
class_ = self.__class__
|
||||
@@ -336,7 +343,7 @@ class ModelPatcher:
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
@@ -698,7 +705,7 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
def _load_list(self, for_dynamic=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
default = False
|
||||
@@ -726,8 +733,13 @@ class ModelPatcher:
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
|
||||
# Non-dynamic: prioritize by module offload cost.
|
||||
if for_dynamic:
|
||||
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
|
||||
else:
|
||||
sort_criteria = (module_offload_mem,)
|
||||
loading.append(sort_criteria + (module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -1459,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
vbar = self._vbar_get()
|
||||
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||
#would all be avaiable for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
@@ -1507,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
loading = self._load_list(for_dynamic=True, default_device=device_to)
|
||||
loading.sort()
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
*_, module_mem, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
@@ -1579,11 +1585,22 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
|
||||
casted_weight = weight.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_param(self.model, key, casted_weight)
|
||||
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
for key, buf in self.model.named_buffers(recurse=True):
|
||||
if key not in self.backup_buffers:
|
||||
self.backup_buffers[key] = buf
|
||||
module, buf_name = comfy.utils.resolve_attr(self.model, key)
|
||||
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
|
||||
casted_buf = buf.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
@@ -1607,15 +1624,17 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
@@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
device=None
|
||||
#FIXME: This is not good RTTI
|
||||
if not isinstance(weight_a, torch.Tensor):
|
||||
#FIXME: This is really bad RTTI
|
||||
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
device = weight_a
|
||||
if os is None:
|
||||
|
||||
@@ -428,7 +428,7 @@ class CLIP:
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
@@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
def resolve_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
return obj, attrs[-1]
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
delattr(obj, name)
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
setattr(obj, name, value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
|
||||
obj.register_buffer(name, value, persistent=persistent)
|
||||
return prev
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
attrs = attr.split(".")
|
||||
|
||||
@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
@@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
elif isinstance(path, io.BytesIO):
|
||||
# BytesIO has no file extension, so av.open can't infer the format.
|
||||
# Default to mp4 since that's the only supported format anyway.
|
||||
extra_kwargs["format"] = "mp4"
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
|
||||
@@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2226,5 +2239,6 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
14
execution.py
14
execution.py
@@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||
# list widget value to execution, as by default list value is
|
||||
# reserved to represent the connection between nodes.
|
||||
if isinstance(val, dict) and "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
# Unwraps values wrapped in __value__ key or typed wrapper.
|
||||
# This is used to pass list widget values to execution,
|
||||
# as by default list value is reserved to represent the
|
||||
# connection between nodes.
|
||||
if isinstance(val, dict):
|
||||
if "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
|
||||
if input_type == "INT":
|
||||
val = int(val)
|
||||
|
||||
2
nodes.py
2
nodes.py
@@ -951,7 +951,7 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
@@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.4
|
||||
comfy-aimdo>=0.2.6
|
||||
requests
|
||||
|
||||
#non essential dependencies:
|
||||
|
||||
Reference in New Issue
Block a user