mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 16:49:57 +00:00
feat(isolation): DynamicVRAM compatibility for process isolation
DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between model lifecycle and active RPC operations, and false positive memory leak detection from changed finalizer patterns. - Marshal CUDA tensors to CPU before RPC transport for dynamic models - Add operation state tracking + quiescence waits at workflow boundaries - Distinguish proxy reference release from actual leaks in cleanup_models_gc - Fix init order: DynamicVRAM must initialize before isolation proxies - Add RPC timeouts to prevent indefinite hangs on model unavailability - Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles - Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths - Guard isolation overhead so non-isolated workflows are unaffected - Migrate env var to PYISOLATE_CHILD
This commit is contained in:
1
.pyisolate_venvs/ComfyUI-DepthAnythingV2/cache/cache_key
vendored
Normal file
1
.pyisolate_venvs/ComfyUI-DepthAnythingV2/cache/cache_key
vendored
Normal file
@@ -0,0 +1 @@
|
||||
f03e4c88e21504c3
|
||||
81
.pyisolate_venvs/ComfyUI-DepthAnythingV2/cache/node_info.json
vendored
Normal file
81
.pyisolate_venvs/ComfyUI-DepthAnythingV2/cache/node_info.json
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"DepthAnything_V2": {
|
||||
"input_types": {
|
||||
"required": {
|
||||
"da_model": {
|
||||
"__pyisolate_tuple__": [
|
||||
"DAMODEL"
|
||||
]
|
||||
},
|
||||
"images": {
|
||||
"__pyisolate_tuple__": [
|
||||
"IMAGE"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"return_types": [
|
||||
"IMAGE"
|
||||
],
|
||||
"return_names": [
|
||||
"image"
|
||||
],
|
||||
"function": "process",
|
||||
"category": "DepthAnythingV2",
|
||||
"output_node": false,
|
||||
"output_is_list": null,
|
||||
"is_v3": false,
|
||||
"display_name": "Depth Anything V2"
|
||||
},
|
||||
"DownloadAndLoadDepthAnythingV2Model": {
|
||||
"input_types": {
|
||||
"required": {
|
||||
"model": {
|
||||
"__pyisolate_tuple__": [
|
||||
[
|
||||
"depth_anything_v2_vits_fp16.safetensors",
|
||||
"depth_anything_v2_vits_fp32.safetensors",
|
||||
"depth_anything_v2_vitb_fp16.safetensors",
|
||||
"depth_anything_v2_vitb_fp32.safetensors",
|
||||
"depth_anything_v2_vitl_fp16.safetensors",
|
||||
"depth_anything_v2_vitl_fp32.safetensors",
|
||||
"depth_anything_v2_vitg_fp32.safetensors",
|
||||
"depth_anything_v2_metric_hypersim_vitl_fp32.safetensors",
|
||||
"depth_anything_v2_metric_vkitti_vitl_fp32.safetensors"
|
||||
],
|
||||
{
|
||||
"default": "depth_anything_v2_vitl_fp32.safetensors"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"precision": {
|
||||
"__pyisolate_tuple__": [
|
||||
[
|
||||
"auto",
|
||||
"bf16",
|
||||
"fp16",
|
||||
"fp32"
|
||||
],
|
||||
{
|
||||
"default": "auto"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"return_types": [
|
||||
"DAMODEL"
|
||||
],
|
||||
"return_names": [
|
||||
"da_v2_model"
|
||||
],
|
||||
"function": "loadmodel",
|
||||
"category": "DepthAnythingV2",
|
||||
"output_node": false,
|
||||
"output_is_list": null,
|
||||
"is_v3": false,
|
||||
"display_name": "DownloadAndLoadDepthAnythingV2Model"
|
||||
}
|
||||
}
|
||||
1
.pyisolate_venvs/ComfyUI-IsolationToolkit/cache/cache_key
vendored
Normal file
1
.pyisolate_venvs/ComfyUI-IsolationToolkit/cache/cache_key
vendored
Normal file
@@ -0,0 +1 @@
|
||||
4b90e6876f4c0b8c
|
||||
93028
.pyisolate_venvs/ComfyUI-IsolationToolkit/cache/node_info.json
vendored
Normal file
93028
.pyisolate_venvs/ComfyUI-IsolationToolkit/cache/node_info.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ from importlib.metadata import version
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
||||
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
@@ -45,25 +45,7 @@ def get_installed_frontend_version():
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-frontend-package=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
return get_required_packages_versions().get("comfyui-frontend-package", None)
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
@@ -217,25 +199,7 @@ class FrontendManager:
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
return get_required_packages_versions().get("comfyui-workflow-templates", None)
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
|
||||
@@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
DynamicVRAM = "dynamic_vram"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
@@ -262,4 +262,4 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
|
||||
@@ -64,7 +64,7 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
|
||||
@@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
@@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
"""Evict running extensions not needed for current execution."""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
@@ -182,22 +188,33 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||
)
|
||||
graph_uses_isolation = bool(isolated_class_types_in_graph)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
if graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph → evict
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
# If NONE of this extension's nodes are in the execution graph -> evict.
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
@@ -211,7 +228,7 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS:
|
||||
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
@@ -237,6 +254,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
@@ -263,6 +285,50 @@ async def flush_running_extensions_transport_state() -> int:
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
|
||||
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
marker: str = "ISO:wait_model_patcher_idle",
|
||||
) -> bool:
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
start = time.perf_counter()
|
||||
idle = await registry.wait_all_idle(timeout_ms)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
if idle:
|
||||
logger.debug(
|
||||
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
)
|
||||
return True
|
||||
|
||||
states = await registry.get_all_operation_states()
|
||||
logger.error(
|
||||
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
states,
|
||||
)
|
||||
if fail_loud:
|
||||
raise TimeoutError(
|
||||
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
@@ -320,6 +386,7 @@ __all__ = [
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
|
||||
@@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
import torch
|
||||
|
||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "device", "device_str": str(obj)}
|
||||
|
||||
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||
return torch.device(data["device_str"])
|
||||
|
||||
registry.register("device", serialize_device, deserialize_device)
|
||||
|
||||
_VALID_DTYPES = {
|
||||
"float16", "float32", "float64", "bfloat16",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "bool",
|
||||
}
|
||||
|
||||
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||
|
||||
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||
if dtype_name not in _VALID_DTYPES:
|
||||
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||
return getattr(torch, dtype_name)
|
||||
|
||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
@@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
f"ModelSampling in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side pass-through for proxies: do not re-register a proxy as a
|
||||
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
@@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register ModelSampling type and proxy
|
||||
registry.register(
|
||||
"ModelSamplingDiscrete",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousEDM",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingContinuousV",
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
# Register all ModelSampling* and StableCascadeSampling classes dynamically
|
||||
import comfy.model_sampling
|
||||
|
||||
for ms_cls in vars(comfy.model_sampling).values():
|
||||
if not isinstance(ms_cls, type):
|
||||
continue
|
||||
if not issubclass(ms_cls, torch.nn.Module):
|
||||
continue
|
||||
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
|
||||
continue
|
||||
registry.register(
|
||||
ms_cls.__name__,
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
|
||||
@@ -389,7 +389,8 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
@@ -400,7 +401,8 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
node_name,
|
||||
len(result),
|
||||
)
|
||||
return self._wrap_unpicklable_objects(result)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
|
||||
@@ -443,7 +445,10 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.detach() if data.requires_grad else data
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
|
||||
@@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
def _preferred_device() -> Any:
|
||||
for value in args:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
return None
|
||||
|
||||
def _move_result_to_device(obj: Any, device: Any) -> Any:
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device) if obj.device != device else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_move_result_to_device(v, device) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_move_result_to_device(v, device) for v in obj)
|
||||
return obj
|
||||
|
||||
# DynamicVRAM models must keep load/offload decisions in host process.
|
||||
# Child-side CUDA staging here can deadlock before first inference RPC.
|
||||
if self.is_dynamic():
|
||||
out = self._call_rpc("inner_model_apply_model", args, kwargs)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
@@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
@@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
def memory_required(self, input_shape: Any) -> Any:
|
||||
return self._call_rpc("memory_required", input_shape)
|
||||
|
||||
def get_operation_state(self) -> Dict[str, Any]:
|
||||
state = self._call_rpc("get_operation_state")
|
||||
return state if isinstance(state, dict) else {}
|
||||
|
||||
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
|
||||
return bool(self._call_rpc("wait_for_idle", timeout_ms))
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
@@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
self._model_sampling = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
@@ -793,7 +830,11 @@ class _InnerModelProxy:
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if name == "model_sampling":
|
||||
return self._parent._call_rpc("get_model_object", "model_sampling")
|
||||
if self._model_sampling is None:
|
||||
self._model_sampling = self._parent._call_rpc(
|
||||
"get_model_object", "model_sampling"
|
||||
)
|
||||
return self._model_sampling
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
# RPC server for ModelPatcher isolation (child process)
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, List
|
||||
|
||||
try:
|
||||
@@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OperationState:
|
||||
lease: threading.Lock = field(default_factory=threading.Lock)
|
||||
active_count: int = 0
|
||||
active_by_method: dict[str, int] = field(default_factory=dict)
|
||||
total_operations: int = 0
|
||||
last_method: Optional[str] = None
|
||||
last_started_ts: Optional[float] = None
|
||||
last_ended_ts: Optional[float] = None
|
||||
last_elapsed_ms: Optional[float] = None
|
||||
last_error: Optional[str] = None
|
||||
last_thread_id: Optional[int] = None
|
||||
last_loop_id: Optional[int] = None
|
||||
|
||||
|
||||
class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "model"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._pending_cleanup_ids: set[str] = set()
|
||||
self._operation_states: dict[str, _OperationState] = {}
|
||||
self._operation_state_cv = threading.Condition(self._lock)
|
||||
|
||||
def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
|
||||
state = self._operation_states.get(instance_id)
|
||||
if state is None:
|
||||
state = _OperationState()
|
||||
self._operation_states[instance_id] = state
|
||||
return state
|
||||
|
||||
def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
with self._operation_state_cv:
|
||||
state = self._get_or_create_operation_state(instance_id)
|
||||
state.active_count += 1
|
||||
state.active_by_method[method_name] = (
|
||||
state.active_by_method.get(method_name, 0) + 1
|
||||
)
|
||||
state.total_operations += 1
|
||||
state.last_method = method_name
|
||||
state.last_started_ts = start_epoch
|
||||
state.last_thread_id = threading.get_ident()
|
||||
try:
|
||||
state.last_loop_id = id(asyncio.get_running_loop())
|
||||
except RuntimeError:
|
||||
state.last_loop_id = None
|
||||
logger.debug(
|
||||
"ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
|
||||
instance_id,
|
||||
method_name,
|
||||
start_epoch,
|
||||
threading.get_ident(),
|
||||
state.last_loop_id,
|
||||
)
|
||||
return start_epoch, start_perf
|
||||
|
||||
def _end_operation(
|
||||
self,
|
||||
instance_id: str,
|
||||
method_name: str,
|
||||
start_perf: float,
|
||||
error: Optional[BaseException] = None,
|
||||
) -> None:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
with self._operation_state_cv:
|
||||
state = self._get_or_create_operation_state(instance_id)
|
||||
state.active_count = max(0, state.active_count - 1)
|
||||
if method_name in state.active_by_method:
|
||||
remaining = state.active_by_method[method_name] - 1
|
||||
if remaining <= 0:
|
||||
state.active_by_method.pop(method_name, None)
|
||||
else:
|
||||
state.active_by_method[method_name] = remaining
|
||||
state.last_ended_ts = end_epoch
|
||||
state.last_elapsed_ms = elapsed_ms
|
||||
state.last_error = None if error is None else repr(error)
|
||||
if state.active_count == 0:
|
||||
self._operation_state_cv.notify_all()
|
||||
logger.debug(
|
||||
"ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
|
||||
instance_id,
|
||||
method_name,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
None if error is None else type(error).__name__,
|
||||
)
|
||||
|
||||
def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
|
||||
with self._operation_state_cv:
|
||||
state = self._get_or_create_operation_state(instance_id)
|
||||
lease = state.lease
|
||||
with lease:
|
||||
_, start_perf = self._begin_operation(instance_id, method_name)
|
||||
try:
|
||||
result = fn()
|
||||
except Exception as exc:
|
||||
self._end_operation(instance_id, method_name, start_perf, error=exc)
|
||||
raise
|
||||
self._end_operation(instance_id, method_name, start_perf)
|
||||
return result
|
||||
|
||||
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
|
||||
with self._operation_state_cv:
|
||||
state = self._operation_states.get(instance_id)
|
||||
if state is None:
|
||||
return {
|
||||
"instance_id": instance_id,
|
||||
"active_count": 0,
|
||||
"active_methods": [],
|
||||
"total_operations": 0,
|
||||
"last_method": None,
|
||||
"last_started_ts": None,
|
||||
"last_ended_ts": None,
|
||||
"last_elapsed_ms": None,
|
||||
"last_error": None,
|
||||
"last_thread_id": None,
|
||||
"last_loop_id": None,
|
||||
}
|
||||
return {
|
||||
"instance_id": instance_id,
|
||||
"active_count": state.active_count,
|
||||
"active_methods": sorted(state.active_by_method.keys()),
|
||||
"total_operations": state.total_operations,
|
||||
"last_method": state.last_method,
|
||||
"last_started_ts": state.last_started_ts,
|
||||
"last_ended_ts": state.last_ended_ts,
|
||||
"last_elapsed_ms": state.last_elapsed_ms,
|
||||
"last_error": state.last_error,
|
||||
"last_thread_id": state.last_thread_id,
|
||||
"last_loop_id": state.last_loop_id,
|
||||
}
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
|
||||
with self._operation_state_cv:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance is not None:
|
||||
self._id_map.pop(id(instance), None)
|
||||
self._pending_cleanup_ids.discard(instance_id)
|
||||
self._operation_states.pop(instance_id, None)
|
||||
self._operation_state_cv.notify_all()
|
||||
|
||||
async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
|
||||
return self._snapshot_operation_state(instance_id)
|
||||
|
||||
async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
|
||||
with self._operation_state_cv:
|
||||
ids = sorted(self._operation_states.keys())
|
||||
return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids}
|
||||
|
||||
async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
|
||||
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
|
||||
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
|
||||
with self._operation_state_cv:
|
||||
while True:
|
||||
active = self._operation_states.get(instance_id)
|
||||
if active is None or active.active_count == 0:
|
||||
return True
|
||||
if deadline is None:
|
||||
self._operation_state_cv.wait()
|
||||
continue
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return False
|
||||
self._operation_state_cv.wait(timeout=remaining)
|
||||
|
||||
async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
|
||||
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
|
||||
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
|
||||
with self._operation_state_cv:
|
||||
while True:
|
||||
has_active = any(
|
||||
state.active_count > 0 for state in self._operation_states.values()
|
||||
)
|
||||
if not has_active:
|
||||
return True
|
||||
if deadline is None:
|
||||
self._operation_state_cv.wait()
|
||||
continue
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return False
|
||||
self._operation_state_cv.wait(timeout=remaining)
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
instance = self._get_instance(instance_id)
|
||||
@@ -73,8 +256,14 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
)
|
||||
|
||||
registry = ModelSamplingRegistry()
|
||||
sampling_id = registry.register(result)
|
||||
# Preserve identity when upstream already returned a proxy. Re-registering
|
||||
# a proxy object creates proxy-of-proxy call chains.
|
||||
if isinstance(result, ModelSamplingProxy):
|
||||
sampling_id = result._instance_id
|
||||
else:
|
||||
sampling_id = registry.register(result)
|
||||
return ModelSamplingProxy(sampling_id, registry)
|
||||
|
||||
return detach_if_grad(result)
|
||||
|
||||
async def get_model_options(self, instance_id: str) -> dict:
|
||||
@@ -163,7 +352,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
return self._get_instance(instance_id).lowvram_patch_counter()
|
||||
|
||||
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
|
||||
return self._get_instance(instance_id).memory_required(input_shape)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"memory_required",
|
||||
lambda: self._get_instance(instance_id).memory_required(input_shape),
|
||||
)
|
||||
|
||||
async def is_dynamic(self, instance_id: str) -> bool:
|
||||
instance = self._get_instance(instance_id)
|
||||
@@ -186,7 +379,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
return None
|
||||
|
||||
async def model_dtype(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).model_dtype()
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"model_dtype",
|
||||
lambda: self._get_instance(instance_id).model_dtype(),
|
||||
)
|
||||
|
||||
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
|
||||
return self._get_instance(instance_id).model_patches_to(device)
|
||||
@@ -198,8 +395,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
extra_memory: Any,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).partially_load(
|
||||
device, extra_memory, force_patch_weights=force_patch_weights
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"partially_load",
|
||||
lambda: self._get_instance(instance_id).partially_load(
|
||||
device, extra_memory, force_patch_weights=force_patch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def partially_unload(
|
||||
@@ -209,8 +410,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
memory_to_free: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
) -> int:
|
||||
return self._get_instance(instance_id).partially_unload(
|
||||
device_to, memory_to_free, force_patch_weights
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"partially_unload",
|
||||
lambda: self._get_instance(instance_id).partially_unload(
|
||||
device_to, memory_to_free, force_patch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def load(
|
||||
@@ -221,8 +426,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._get_instance(instance_id).load(
|
||||
device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"load",
|
||||
lambda: self._get_instance(instance_id).load(
|
||||
device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
),
|
||||
)
|
||||
|
||||
async def patch_model(
|
||||
@@ -233,20 +442,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
self._get_instance(instance_id).patch_model(
|
||||
device_to, lowvram_model_memory, load_weights, force_patch_weights
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
|
||||
)
|
||||
return
|
||||
def _invoke() -> None:
|
||||
try:
|
||||
self._get_instance(instance_id).patch_model(
|
||||
device_to, lowvram_model_memory, load_weights, force_patch_weights
|
||||
)
|
||||
except AttributeError as e:
|
||||
logger.error(
|
||||
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
|
||||
)
|
||||
return
|
||||
|
||||
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
|
||||
|
||||
async def unpatch_model(
|
||||
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
|
||||
self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"unpatch_model",
|
||||
lambda: self._get_instance(instance_id).unpatch_model(
|
||||
device_to, unpatch_weights
|
||||
),
|
||||
)
|
||||
|
||||
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
|
||||
self._get_instance(instance_id).detach(unpatch_all)
|
||||
@@ -262,26 +480,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
self._get_instance(instance_id).pre_run()
|
||||
|
||||
async def cleanup(self, instance_id: str) -> None:
|
||||
try:
|
||||
instance = self._get_instance(instance_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcher cleanup requested for missing instance %s",
|
||||
instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
def _invoke() -> None:
|
||||
try:
|
||||
instance = self._get_instance(instance_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcher cleanup requested for missing instance %s",
|
||||
instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
instance.cleanup()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._pending_cleanup_ids.add(instance_id)
|
||||
gc.collect()
|
||||
try:
|
||||
instance.cleanup()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._pending_cleanup_ids.add(instance_id)
|
||||
gc.collect()
|
||||
|
||||
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
|
||||
|
||||
def sweep_pending_cleanup(self) -> int:
|
||||
removed = 0
|
||||
with self._lock:
|
||||
with self._operation_state_cv:
|
||||
pending_ids = list(self._pending_cleanup_ids)
|
||||
self._pending_cleanup_ids.clear()
|
||||
for instance_id in pending_ids:
|
||||
@@ -289,17 +510,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
if instance is None:
|
||||
continue
|
||||
self._id_map.pop(id(instance), None)
|
||||
self._operation_states.pop(instance_id, None)
|
||||
removed += 1
|
||||
self._operation_state_cv.notify_all()
|
||||
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
def purge_all(self) -> int:
|
||||
with self._lock:
|
||||
with self._operation_state_cv:
|
||||
removed = len(self._registry)
|
||||
self._registry.clear()
|
||||
self._id_map.clear()
|
||||
self._pending_cleanup_ids.clear()
|
||||
self._operation_states.clear()
|
||||
self._operation_state_cv.notify_all()
|
||||
gc.collect()
|
||||
return removed
|
||||
|
||||
@@ -743,17 +968,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
async def inner_model_memory_required(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"inner_model_memory_required",
|
||||
lambda: self._get_instance(instance_id).model.memory_required(
|
||||
*args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def inner_model_extra_conds_shapes(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
|
||||
return self._run_operation_with_lease(
|
||||
instance_id,
|
||||
"inner_model_extra_conds_shapes",
|
||||
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
|
||||
*args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def inner_model_extra_conds(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
def _invoke() -> Any:
|
||||
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
|
||||
try:
|
||||
import torch
|
||||
import comfy.conds
|
||||
except Exception:
|
||||
return result
|
||||
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
if isinstance(obj, comfy.conds.CONDRegular):
|
||||
return type(obj)(_to_cpu(obj.cond))
|
||||
return obj
|
||||
|
||||
return _to_cpu(result)
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
|
||||
|
||||
async def inner_model_state_dict(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
@@ -767,82 +1027,177 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
|
||||
async def inner_model_apply_model(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
def _invoke() -> Any:
|
||||
import torch
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
instance = self._get_instance(instance_id)
|
||||
target = getattr(instance, "load_device", None)
|
||||
if target is None and args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif target is None:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
|
||||
def _move(obj):
|
||||
if target is None:
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(_move(o) for o in obj)
|
||||
if hasattr(obj, "to"):
|
||||
return obj.to(target)
|
||||
return obj
|
||||
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
return detach_if_grad(_move(result))
|
||||
moved_args = tuple(_move(a) for a in args)
|
||||
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
|
||||
result = instance.model.apply_model(*moved_args, **moved_kwargs)
|
||||
moved_result = detach_if_grad(_move(result))
|
||||
|
||||
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
|
||||
# at the transport boundary. Marshal dynamic-path results as CPU and let
|
||||
# the proxy restore device placement in the child process.
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
|
||||
|
||||
async def process_latent_in(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
|
||||
)
|
||||
import torch
|
||||
|
||||
def _invoke() -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = detach_if_grad(instance.model.process_latent_in(*args, **kwargs))
|
||||
|
||||
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
|
||||
# at the transport boundary. Marshal dynamic-path results as CPU and let
|
||||
# the proxy restore placement when needed.
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(result)
|
||||
return result
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "process_latent_in", _invoke)
|
||||
|
||||
async def process_latent_out(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
import torch
|
||||
|
||||
def _invoke() -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.process_latent_out(*args, **kwargs)
|
||||
moved_result = None
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
moved_result = detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"process_latent_out: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
if moved_result is None:
|
||||
moved_result = detach_if_grad(result)
|
||||
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
|
||||
|
||||
async def scale_latent_inpaint(
|
||||
self, instance_id: str, args: tuple, kwargs: dict
|
||||
) -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
return detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
return detach_if_grad(result)
|
||||
import torch
|
||||
|
||||
def _invoke() -> Any:
|
||||
instance = self._get_instance(instance_id)
|
||||
result = instance.model.scale_latent_inpaint(*args, **kwargs)
|
||||
moved_result = None
|
||||
try:
|
||||
target = None
|
||||
if args and hasattr(args[0], "device"):
|
||||
target = args[0].device
|
||||
elif kwargs:
|
||||
for v in kwargs.values():
|
||||
if hasattr(v, "device"):
|
||||
target = v.device
|
||||
break
|
||||
if target is not None and hasattr(result, "to"):
|
||||
moved_result = detach_if_grad(result.to(target))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"scale_latent_inpaint: failed to move result to target device",
|
||||
exc_info=True,
|
||||
)
|
||||
if moved_result is None:
|
||||
moved_result = detach_if_grad(result)
|
||||
|
||||
is_dynamic_fn = getattr(instance, "is_dynamic", None)
|
||||
if callable(is_dynamic_fn) and is_dynamic_fn():
|
||||
def _to_cpu(obj: Any) -> Any:
|
||||
if torch.is_tensor(obj):
|
||||
return obj.detach().cpu() if obj.device.type != "cpu" else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cpu(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cpu(v) for v in obj)
|
||||
return obj
|
||||
|
||||
return _to_cpu(moved_result)
|
||||
return moved_result
|
||||
|
||||
return self._run_operation_with_lease(
|
||||
instance_id, "scale_latent_inpaint", _invoke
|
||||
)
|
||||
|
||||
async def load_lora(
|
||||
self,
|
||||
|
||||
@@ -3,6 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
@@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
torch = None
|
||||
try:
|
||||
if torch is not None and isinstance(obj, torch.Tensor):
|
||||
return (
|
||||
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
||||
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
@@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
t = obj.detach() if obj.requires_grad else obj
|
||||
if t.is_cuda:
|
||||
return t.to("cpu")
|
||||
return t
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_cpu_for_rpc(x) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
@@ -199,14 +236,70 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(result)
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(result)
|
||||
return result
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _rpc_timeout_ms() -> int:
|
||||
raw = os.environ.get(
|
||||
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
||||
)
|
||||
try:
|
||||
timeout_ms = int(raw)
|
||||
except ValueError:
|
||||
timeout_ms = 30000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
@@ -235,10 +328,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
return self._call("inverse_noise_scaling", sigma, latent)
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
@@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
_TIMEOUT_RPC_METHODS = frozenset(
|
||||
{
|
||||
"partially_load",
|
||||
"partially_unload",
|
||||
"load",
|
||||
"patch_model",
|
||||
"unpatch_model",
|
||||
"inner_model_apply_model",
|
||||
"memory_required",
|
||||
"model_dtype",
|
||||
"inner_model_memory_required",
|
||||
"inner_model_extra_conds_shapes",
|
||||
"inner_model_extra_conds",
|
||||
"process_latent_in",
|
||||
"process_latent_out",
|
||||
"scale_latent_inpaint",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
|
||||
if method_name not in self._TIMEOUT_RPC_METHODS:
|
||||
return None
|
||||
try:
|
||||
timeout_ms = int(
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
|
||||
)
|
||||
except ValueError:
|
||||
timeout_ms = 120000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
# If we are already in the global loop, we can't block on it?
|
||||
# Actually, this method is synchronous (__getattr__ -> lambda).
|
||||
# If called from async context in main loop, we need to handle that.
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
# We are in the main loop. We cannot await/block here if we are just a sync function.
|
||||
# But proxies are often called from sync code.
|
||||
# If called from sync code in main loop, creating a new loop is bad.
|
||||
# But we can't await `coro`.
|
||||
# This implies proxies MUST be awaited if called from async context?
|
||||
# Existing code used `run_coro_in_new_loop` which is weird.
|
||||
# Let's trust that if we are in a thread (RuntimeError on get_running_loop),
|
||||
# we use run_coroutine_threadsafe.
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
loop_id: Optional[int] = id(running_loop)
|
||||
except RuntimeError:
|
||||
loop_id = None
|
||||
logger.debug(
|
||||
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
|
||||
"thread=%s loop=%s timeout_ms=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
loop_id,
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop - we are in a worker thread.
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result(
|
||||
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
except concurrent.futures.TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
finally:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
|
||||
"elapsed_ms=%.3f thread=%s loop=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
loop_id,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {"_instance_id": self._instance_id}
|
||||
|
||||
@@ -192,7 +192,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if isolation_active:
|
||||
target_device = sigmas.device
|
||||
if x.device != target_device:
|
||||
|
||||
@@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
class SCAILWanModel(WanModel):
|
||||
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
del scail_x
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.cat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if scail_pose_seq_len > 0:
|
||||
x = x[:, :-scail_pose_seq_len]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
|
||||
if reference_latent is not None:
|
||||
x = x[:, :, reference_latent.shape[2]:]
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
return main_freqs
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
|
||||
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
|
||||
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
reference_latent = None
|
||||
if "reference_latent" in kwargs:
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@@ -76,6 +77,7 @@ class ModelType(Enum):
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@@ -108,17 +110,23 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
from comfy.cli_args import args
|
||||
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
class ModelSampling(s, c):
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
if isolation_runtime_enabled:
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
@@ -998,6 +1006,10 @@ class LTXV(BaseModel):
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
@@ -1050,6 +1062,10 @@ class LTXAV(BaseModel):
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@@ -1493,6 +1509,50 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
class WAN21_SCAIL(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
|
||||
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||
ref_mask = torch.ones_like(ref_latent[:, :4])
|
||||
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
pose_latents = self.process_latent_in(pose_latents)
|
||||
pose_mask = torch.ones_like(pose_latents[:, :4])
|
||||
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
@@ -498,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -531,8 +533,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
@@ -1053,6 +1054,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
|
||||
n_layers = count_blocks(state_dict, 'layers.{}.')
|
||||
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
|
||||
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
|
||||
for k in state_dict: # For zeta chroma
|
||||
if k not in sd_map:
|
||||
sd_map[k] = k
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
|
||||
@@ -180,6 +180,14 @@ def is_ixuca():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_wsl():
|
||||
version = platform.uname().release
|
||||
if version.endswith("-Microsoft"):
|
||||
return True
|
||||
elif version.endswith("microsoft-standard-WSL2"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -475,6 +483,9 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def _isolation_mode_enabled():
|
||||
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@@ -554,8 +565,9 @@ class LoadedModel:
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
if self.model_finalizer is not None:
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
@@ -569,14 +581,15 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def dead_state(self):
|
||||
model_ref_gone = self.model is None
|
||||
real_model_ref = self.real_model
|
||||
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
|
||||
return model_ref_gone, real_model_ref_gone
|
||||
|
||||
def is_dead(self):
|
||||
# Model is dead if the weakref to model has been garbage collected
|
||||
# This can happen with ModelPatcherProxy objects between isolated workflows
|
||||
if self.model is None:
|
||||
return True
|
||||
if self.real_model is None:
|
||||
return False
|
||||
return self.real_model() is None
|
||||
model_ref_gone, real_model_ref_gone = self.dead_state()
|
||||
return model_ref_gone or real_model_ref_gone
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@@ -622,7 +635,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
isolation_active = _isolation_mode_enabled()
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@@ -649,12 +662,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
@@ -728,7 +740,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
if model_to_unload.model_finalizer is not None:
|
||||
model_to_unload.model_finalizer.detach()
|
||||
model_to_unload.model_finalizer = None
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
@@ -792,21 +806,55 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
reset_cast_buffers()
|
||||
if not _isolation_mode_enabled():
|
||||
dead_found = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
dead_found = True
|
||||
break
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
return
|
||||
|
||||
dead_found = False
|
||||
has_real_model_leak = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
dead_found = True
|
||||
break
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
has_real_model_leak = True
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
if has_real_model_leak:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
@@ -824,7 +872,11 @@ def archive_model_dtypes(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
real_model_ref = current_loaded_models[i].real_model
|
||||
if real_model_ref is None:
|
||||
to_delete = [i] + to_delete
|
||||
continue
|
||||
if callable(real_model_ref) and real_model_ref() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
|
||||
@@ -308,15 +308,22 @@ class ModelPatcher:
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
def clone(self, disable_dynamic=False):
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
class_ = self.__class__
|
||||
model = self.model
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
model = temp_model_patcher.model
|
||||
if model_override is None:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
|
||||
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -325,13 +332,12 @@ class ModelPatcher:
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
for k in self.attachments:
|
||||
@@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
del self.model.model_loaded_weight_memory
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
def is_dynamic(self):
|
||||
@@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
pass
|
||||
|
||||
def get_non_dynamic_delegate(self):
|
||||
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
|
||||
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
|
||||
return model_patcher
|
||||
|
||||
|
||||
CoreModelPatcher = ModelPatcher
|
||||
|
||||
@@ -66,6 +66,18 @@ def convert_cond(cond):
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def cond_has_hooks(cond):
|
||||
for c in cond:
|
||||
temp = c[1]
|
||||
if "hooks" in temp:
|
||||
return True
|
||||
if "control" in temp:
|
||||
control = temp["control"]
|
||||
extra_hooks = control.get_extra_hooks()
|
||||
if len(extra_hooks) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets: list[ControlBase] = []
|
||||
|
||||
@@ -212,10 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -272,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@@ -411,7 +413,7 @@ class KSamplerX0Inpaint:
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if denoise_mask is not None:
|
||||
if isolation_active and denoise_mask.device != x.device:
|
||||
denoise_mask = denoise_mask.to(x.device)
|
||||
@@ -777,7 +779,11 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
noise = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
@@ -982,6 +988,8 @@ class CFGGuider:
|
||||
|
||||
def inner_set_conds(self, conds):
|
||||
for k in conds:
|
||||
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
|
||||
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
38
comfy/sd.py
38
comfy/sd.py
@@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@@ -233,7 +233,8 @@ class CLIP:
|
||||
model_management.archive_model_dtypes(self.cond_stage_model)
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
@@ -267,9 +268,9 @@ class CLIP:
|
||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
self.tokenizer_options = {}
|
||||
|
||||
def clone(self):
|
||||
def clone(self, disable_dynamic=False):
|
||||
n = CLIP(no_init=True)
|
||||
n.patcher = self.patcher.clone()
|
||||
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
n.layer_idx = self.layer_idx
|
||||
@@ -1164,14 +1165,21 @@ class CLIPType(Enum):
|
||||
LONGCAT_IMAGE = 26
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
|
||||
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
|
||||
return clip.patcher
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if model_options.get("custom_operations", None) is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
|
||||
return clip
|
||||
|
||||
|
||||
class TEModel(Enum):
|
||||
@@ -1276,7 +1284,7 @@ def llama_detect(clip_data):
|
||||
|
||||
return {}
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||
clip_data = state_dicts
|
||||
|
||||
class EmptyClass:
|
||||
@@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
@@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if output_model:
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
@@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
|
||||
disable_dynamic=disable_dynamic)
|
||||
return model
|
||||
|
||||
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
|
||||
embedding_directory=embedding_directory, output_model=False,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic)
|
||||
return clip.patcher
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||
clip = None
|
||||
clipvision = None
|
||||
@@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
|
||||
@@ -1268,6 +1268,16 @@ class WAN21_FlowRVS(WAN21_T2V):
|
||||
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_SCAIL(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "scail",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@@ -1710,6 +1720,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model = "gemini-3.1-flash-image-preview"
|
||||
if response_modalities == "IMAGE+TEXT":
|
||||
raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.")
|
||||
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images is not None:
|
||||
@@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE"],
|
||||
options=["IMAGE", "IMAGE+TEXT"],
|
||||
advanced=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
|
||||
@@ -20,7 +20,7 @@ class JobStatus:
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||
@@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
|
||||
normalized[node_id] = normalized_node
|
||||
return normalized
|
||||
|
||||
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
|
||||
TEXT_PREVIEW_MAX_LENGTH = 1024
|
||||
|
||||
|
||||
def _create_text_preview(value: str) -> dict:
|
||||
"""Create a text preview dict with optional truncation.
|
||||
|
||||
Returns:
|
||||
dict with 'content' and optionally 'truncated' flag
|
||||
"""
|
||||
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
|
||||
return {'content': value}
|
||||
return {
|
||||
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
|
||||
'truncated': True
|
||||
}
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
@@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
# Handle text outputs (non-dict items like strings or tuples)
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
# Not a 3D file string — check for text preview
|
||||
if media_type == 'text':
|
||||
count += 1
|
||||
if preview_output is None:
|
||||
if isinstance(item, tuple):
|
||||
text_value = item[0] if item else ''
|
||||
else:
|
||||
text_value = str(item)
|
||||
text_preview = _create_text_preview(text_value)
|
||||
enriched = {
|
||||
**text_preview,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
continue
|
||||
# normalize_output_item returned a dict (e.g. 3D file)
|
||||
item = normalized
|
||||
|
||||
count += 1
|
||||
|
||||
if preview_output is not None:
|
||||
continue
|
||||
|
||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
||||
if is_previewable(media_type, item):
|
||||
enriched = {
|
||||
**normalized,
|
||||
**item,
|
||||
'nodeId': node_id,
|
||||
}
|
||||
if 'mediaType' not in normalized:
|
||||
if 'mediaType' not in item:
|
||||
enriched['mediaType'] = media_type
|
||||
if normalized.get('type') == 'output':
|
||||
if item.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
@@ -248,7 +248,7 @@ class SetClipHooks:
|
||||
|
||||
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||
if hooks is not None:
|
||||
clip = clip.clone()
|
||||
clip = clip.clone(disable_dynamic=True)
|
||||
if apply_to_conds:
|
||||
clip.apply_hooks_to_conds = hooks
|
||||
clip.patcher.forced_hooks = hooks.clone()
|
||||
|
||||
@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Mahiro CFG",
|
||||
display_name="Positive-Biased Guidance",
|
||||
category="_for_testing",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
@@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
|
||||
io.Model.Output(display_name="patched_model"),
|
||||
],
|
||||
is_experimental=True,
|
||||
search_aliases=[
|
||||
"mahiro",
|
||||
"mahiro cfg",
|
||||
"similarity-adaptive guidance",
|
||||
"positive-biased cfg",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
def mahiro_normd(args):
|
||||
scale: float = args['cond_scale']
|
||||
cond_p: torch.Tensor = args['cond_denoised']
|
||||
uncond_p: torch.Tensor = args['uncond_denoised']
|
||||
#naive leap
|
||||
scale: float = args["cond_scale"]
|
||||
cond_p: torch.Tensor = args["cond_denoised"]
|
||||
uncond_p: torch.Tensor = args["uncond_denoised"]
|
||||
# naive leap
|
||||
leap = cond_p * scale
|
||||
#sim with uncond leap
|
||||
# sim with uncond leap
|
||||
u_leap = uncond_p * scale
|
||||
cfg = args["denoised"]
|
||||
merge = (leap + cfg) / 2
|
||||
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
||||
normm = torch.sqrt(merge.abs()) * merge.sign()
|
||||
sim = F.cosine_similarity(normu, normm).mean()
|
||||
simsc = 2 * (sim+1)
|
||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
||||
simsc = 2 * (sim + 1)
|
||||
wm = (simsc * cfg + (4 - simsc) * leap) / 4
|
||||
return wm
|
||||
|
||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
@@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
    """Prepare conditioning and an empty latent for Wan SCAIL video generation.

    Optionally attaches a VAE-encoded reference image, a CLIP vision output,
    and a pose-video latent (active only within a timestep sub-range) to the
    positive/negative conditioning.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs/outputs for the schema system."""
        return io.Schema(
            node_id="WanSCAILToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("reference_image", optional=True),
                io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
                io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
                io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
                io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
        """Return (positive, negative, empty latent) with optional guidance applied."""
        # Zero-filled 16-channel latent; time axis is (length - 1) // 4 + 1 and the
        # spatial dims are width // 8 x height // 8 (assumed VAE downscale factors
        # for this model family — TODO confirm for SCAIL).
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())

        ref_latent = None
        if reference_image is not None:
            # Only the first frame is used, resized to the target resolution and
            # encoded from its RGB channels.
            reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            ref_latent = vae.encode(reference_image[:, :, :, :3])

        if ref_latent is not None:
            # Positive side carries the real reference latent; the negative side
            # gets a zero latent of identical shape.
            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)

        if clip_vision_output is not None:
            # Same CLIP vision output is attached to both sides.
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        if pose_video is not None:
            # Pose video: trim to `length` frames, downscale to half resolution,
            # encode, scale by pose_strength, and restrict to [pose_start, pose_end].
            pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
            pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
            positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
            negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
WanSCAILToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
69
execution.py
69
execution.py
@@ -43,6 +43,8 @@ from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@@ -540,7 +542,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None)
|
||||
if vbar_lib is not None:
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
else:
|
||||
global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED
|
||||
if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED:
|
||||
logging.warning(
|
||||
"DynamicVRAM backend unavailable for watermark reset; "
|
||||
"skipping vbar reset for this process."
|
||||
)
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
@@ -669,6 +681,8 @@ class PromptExecutor:
|
||||
self.success = True
|
||||
|
||||
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import notify_execution_graph
|
||||
await notify_execution_graph(class_types)
|
||||
@@ -678,12 +692,34 @@ class PromptExecutor:
|
||||
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
|
||||
|
||||
async def _flush_running_extensions_transport_state_safe(self) -> None:
    """Best-effort flush of RPC transport state for running isolated extensions.

    No-op unless process isolation is enabled; any failure is logged at debug
    level and never propagates to the caller.
    """
    if args.use_process_isolation:
        try:
            from comfy.isolation import flush_running_extensions_transport_state as _flush

            await _flush()
        except Exception:
            logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
|
||||
|
||||
async def _wait_model_patcher_quiescence_safe(
    self,
    *,
    fail_loud: bool = False,
    timeout_ms: int = 120000,
    marker: str = "EX:wait_model_patcher_idle",
) -> None:
    """Wait until isolated model patchers are idle before proceeding.

    No-op unless process isolation is enabled. When fail_loud is False,
    failures (including timeouts) are swallowed and logged at debug level;
    when True they are re-raised to the caller.
    """
    if not args.use_process_isolation:
        return
    try:
        from comfy.isolation import wait_for_model_patcher_quiescence as _wait_idle

        await _wait_idle(timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker)
    except Exception:
        if fail_loud:
            raise
        logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
@@ -725,16 +761,17 @@ class PromptExecutor:
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
# Update RPC event loops for all isolated extensions
|
||||
# This is critical for serial workflow execution - each asyncio.run() creates
|
||||
# a new event loop, and RPC instances must be updated to use it
|
||||
try:
|
||||
from comfy.isolation import update_rpc_event_loops
|
||||
update_rpc_event_loops()
|
||||
except ImportError:
|
||||
pass # Isolation not available
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
|
||||
if args.use_process_isolation:
|
||||
# Update RPC event loops for all isolated extensions.
|
||||
# This is critical for serial workflow execution - each asyncio.run() creates
|
||||
# a new event loop, and RPC instances must be updated to use it.
|
||||
try:
|
||||
from comfy.isolation import update_rpc_event_loops
|
||||
update_rpc_event_loops()
|
||||
except ImportError:
|
||||
pass # Isolation not available
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
|
||||
@@ -754,6 +791,11 @@ class PromptExecutor:
|
||||
# Boundary cleanup runs at the start of the next workflow in
|
||||
# isolation mode, matching non-isolated "next prompt" timing.
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=False,
|
||||
timeout_ms=120000,
|
||||
marker="EX:boundary_cleanup_wait_idle",
|
||||
)
|
||||
await self._flush_running_extensions_transport_state_safe()
|
||||
comfy.model_management.unload_all_models()
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
@@ -794,6 +836,11 @@ class PromptExecutor:
|
||||
for node_id in execution_list.pendingNodes.keys():
|
||||
class_type = dynamic_prompt.get_node(node_id)["class_type"]
|
||||
pending_class_types.add(class_type)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=True,
|
||||
timeout_ms=120000,
|
||||
marker="EX:notify_graph_wait_idle",
|
||||
)
|
||||
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
|
||||
16
main.py
16
main.py
@@ -25,6 +25,15 @@ from app.assets.scanner import seed_assets
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if not comfy_aimdo.control.init():
|
||||
logging.warning(
|
||||
"DynamicVRAM requested, but comfy-aimdo failed to initialize early. "
|
||||
"Will fall back to legacy model loading if device init fails."
|
||||
)
|
||||
|
||||
if '--use-process-isolation' in sys.argv:
|
||||
from comfy.isolation import initialize_proxies
|
||||
initialize_proxies()
|
||||
@@ -208,11 +217,6 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
comfy_aimdo.control.init()
|
||||
|
||||
import comfy.utils
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
@@ -228,7 +232,7 @@ if not IS_PYISOLATE_CHILD:
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
|
||||
|
||||
return c
|
||||
|
||||
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
    """
    Apply values to conditioning only during [start_percent, end_percent], keeping the
    original conditioning active outside that range. Respects existing per-entry ranges.
    """
    if start_percent > end_percent:
        logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
        return conditioning

    EPS = 1e-5  # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep
    result = []
    for entry in conditioning:
        entry_start = entry[1].get("start_percent", 0.0)
        entry_end = entry[1].get("end_percent", 1.0)
        overlap_start = max(start_percent, entry_start)
        overlap_end = min(end_percent, entry_end)

        # No overlap with the requested window: pass the entry through unchanged.
        if overlap_start >= overlap_end:
            result.append(entry)
            continue

        # Leading segment (before the window) keeps the original values.
        if overlap_start > entry_start:
            result.extend(conditioning_set_values([entry], {"start_percent": entry_start, "end_percent": overlap_start - EPS}))

        # The window itself carries the requested values.
        result.extend(conditioning_set_values([entry], {**values, "start_percent": overlap_start, "end_percent": overlap_end}))

        # Trailing segment (after the window) keeps the original values.
        if overlap_end < entry_end:
            result.extend(conditioning_set_values([entry], {"start_percent": overlap_end + EPS, "end_percent": entry_end}))
    return result
|
||||
|
||||
def pillow(fn, arg):
|
||||
prev_value = None
|
||||
try:
|
||||
|
||||
@@ -10,6 +10,25 @@ homepage = "https://www.comfy.org/"
|
||||
repository = "https://github.com/comfyanonymous/ComfyUI"
|
||||
documentation = "https://docs.comfy.org/"
|
||||
|
||||
[tool.comfy.host]
|
||||
allow_network = false
|
||||
writable_paths = ["/dev/shm", "/tmp"]
|
||||
|
||||
[tool.comfy.host.whitelist]
|
||||
"ComfyUI-Crystools" = "*"
|
||||
"ComfyUI-Florence2" = "*"
|
||||
"ComfyUI-GGUF" = "*"
|
||||
"ComfyUI-KJNodes" = "*"
|
||||
"ComfyUI-LTXVideo" = "*"
|
||||
"ComfyUI-Manager" = "*"
|
||||
"comfyui-depthanythingv2" = "*"
|
||||
"comfyui-kjnodes" = "*"
|
||||
"comfyui-videohelpersuite" = "*"
|
||||
"comfyui_controlnet_aux" = "*"
|
||||
"rgthree-comfy" = "*"
|
||||
"was-ns" = "*"
|
||||
"websocket_image_save.py" = "*"
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [
|
||||
"N805", # invalid-first-argument-name-for-method
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.39.19
|
||||
comfyui-workflow-templates==0.9.4
|
||||
comfyui-workflow-templates==0.9.5
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.2
|
||||
comfy-aimdo>=0.2.4
|
||||
requests
|
||||
|
||||
#non essential dependencies:
|
||||
|
||||
@@ -49,6 +49,12 @@ def mock_provider(mock_releases):
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached requirements-version map before every test."""
    import utils.install_util as install_util

    # Rebind to a fresh dict so each test starts with an empty cache.
    install_util.PACKAGE_VERSIONS = {}
|
||||
|
||||
def test_get_release(mock_provider, mock_releases):
|
||||
version = "1.0.0"
|
||||
release = mock_provider.get_release(version)
|
||||
|
||||
@@ -38,13 +38,13 @@ class TestIsPreviewable:
|
||||
"""Unit tests for is_previewable()"""
|
||||
|
||||
def test_previewable_media_types(self):
|
||||
"""Images, video, audio, 3d media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio', '3d']:
|
||||
"""Images, video, audio, 3d, text media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio', '3d', 'text']:
|
||||
assert is_previewable(media_type, {}) is True
|
||||
|
||||
def test_non_previewable_media_types(self):
|
||||
"""Other media types should not be previewable."""
|
||||
for media_type in ['latents', 'text', 'metadata', 'files']:
|
||||
for media_type in ['latents', 'metadata', 'files']:
|
||||
assert is_previewable(media_type, {}) is False
|
||||
|
||||
def test_3d_extensions_previewable(self):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
|
||||
# The path to the requirements.txt file
|
||||
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
@@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {requirements_path}
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
||||
""".strip()
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
    """Return True if *version* is a strict semantic version (X.Y.Z format).

    Uses re.fullmatch instead of re.match with a ``$`` anchor: ``$`` also
    matches just before a trailing newline, so the old form accepted inputs
    like ``"1.0.0\\n"``. fullmatch requires the whole string to match.
    """
    pattern = r"(\d+)\.(\d+)\.(\d+)"
    return bool(re.fullmatch(pattern, version))
|
||||
|
||||
|
||||
# Module-level cache: package name -> pinned version string from requirements.txt.
PACKAGE_VERSIONS = {}
def get_required_packages_versions():
    """Parse requirements.txt into a {package: version} mapping, with caching.

    Returns:
        A copy of the mapping, or None if requirements.txt is missing or
        unreadable.

    The module-level cache is committed only after a fully successful parse.
    (Previously lines were written straight into PACKAGE_VERSIONS, so an
    exception mid-read left a partial cache that subsequent calls returned
    as if it were complete.)
    """
    if len(PACKAGE_VERSIONS) > 0:
        return PACKAGE_VERSIONS.copy()
    parsed = {}
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for line in f:
                # Treat ">=" pins like exact "==" pins for version reporting.
                line = line.strip().replace(">=", "==")
                s = line.split("==")
                if len(s) == 2:
                    version_str = s[-1]
                    if not is_valid_version(version_str):
                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
                        continue
                    parsed[s[0]] = version_str
    except FileNotFoundError:
        logging.error("requirements.txt not found.")
        return None
    except Exception as e:
        logging.error(f"Error reading requirements.txt: {e}")
        return None
    # Commit to the shared cache only after a complete, error-free parse.
    PACKAGE_VERSIONS.update(parsed)
    return PACKAGE_VERSIONS.copy()
||||
|
||||
Reference in New Issue
Block a user