Compare commits

..

49 Commits

Author SHA1 Message Date
Jedrzej Kosinski
3ddcc438dd Fixed redundant comment 2026-02-12 23:03:46 -08:00
Jedrzej Kosinski
0132eb0889 Slight fix to a type annotation 2026-02-12 22:46:04 -08:00
Jedrzej Kosinski
aac40e4fca Skip replacements that lead to a missing node_id anyway 2026-02-12 22:29:15 -08:00
Jedrzej Kosinski
5b4ba294a1 Merge branch 'master' into jk/node-replace-api 2026-02-12 22:25:03 -08:00
Jedrzej Kosinski
23338b5382 Move node replacement registration for core nodes into nodes_replacements.py 2026-02-12 22:23:26 -08:00
Jedrzej Kosinski
14184a0918 Add apply_replacements to NodeReplaceManager to apply registered node replacements when executing /prompt endpoint 2026-02-12 22:04:52 -08:00
comfyanonymous
e03fe8b591 Update command to install AMD stable linux pytorch. (#12437) 2026-02-12 23:29:12 -05:00
Jedrzej Kosinski
e50ab676ae Merge branch 'master' into jk/node-replace-api 2026-02-12 19:49:54 -08:00
rattus
ae79e33345 llama: use a more efficient rope implementation (#12434)
Get rid of the cat and unary negation and instead in-place addcmul the two
halves of the rope. Precompute -sin once at the start of the model
rather than in every transformer block.

This is slightly faster on both GPU- and CPU-bound setups.
2026-02-12 19:56:42 -05:00
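The corresponding hunk appears further down in the llama diff. As a standalone illustration (plain PyTorch, independent of the repository code, with illustrative shapes), this sketch contrasts the rotate_half/cat formulation with the split-half in-place addcmul variant that consumes a precomputed -sin:

```
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_cat(x, cos, sin):
    # baseline: allocates an extra tensor via cat() plus a unary negation per call
    return (x * cos) + (rotate_half(x) * sin)

def apply_rope_addcmul(x, cos, sin_half, neg_sin_half):
    # optimized: operate on the two halves in place; -sin is precomputed once
    # per forward pass instead of being negated in every transformer block
    out = x * cos
    split = out.shape[-1] // 2
    out[..., :split].addcmul_(x[..., split:], neg_sin_half)
    out[..., split:].addcmul_(x[..., :split], sin_half)
    return out

# illustrative shapes: (batch, heads, seq, head_dim)
x = torch.randn(1, 2, 4, 8)
pos = torch.arange(4, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 4, dtype=torch.float32) / 4))
angles = torch.outer(pos, inv_freq)                    # (seq, head_dim/2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # (seq, head_dim)
sin_half = angles.sin()                                # both halves are identical
assert torch.allclose(apply_rope_cat(x, cos, torch.cat([sin_half, sin_half], dim=-1)),
                      apply_rope_addcmul(x, cos, sin_half, -sin_half))
```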
rattus
117e214354 ModelPatcherDynamic: force load non leaf weights (#12433)
The current behaviour of the default ModelPatcher is to .to() a model
only if it is fully loaded, which is how random non-leaf weights get
loaded in non-LowVRAM conditions.

This however means they never get loaded in dynamic_vram. In the
dynamic_vram case, force-load them to the GPU.
2026-02-12 19:51:50 -05:00
Jin Yi
e2f7eaff26 Merge branch 'master' into jk/node-replace-api 2026-02-12 18:40:47 +09:00
Alexander Piskun
4a93a62371 fix(api-nodes): add separate retry budget for 429 rate limit responses (#12421) 2026-02-12 01:38:51 -08:00
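A generic sketch of the idea behind this fix, not the api-nodes implementation (all names below are hypothetical): 429 rate-limit responses draw from a separate, larger retry budget so they cannot exhaust the budget reserved for genuine failures.

```
import random
import time

def poll_with_budgets(poll_once, max_errors=3, max_rate_limited=10):
    """poll_once() is a hypothetical callable returning an HTTP status code."""
    errors = rate_limited = 0
    while True:
        status = poll_once()
        if status == 200:
            return "done"
        if status == 429:
            rate_limited += 1   # 429s only consume their own, larger budget
            if rate_limited > max_rate_limited:
                raise RuntimeError("rate-limit retry budget exhausted")
        else:
            errors += 1         # real failures consume the stricter budget
            if errors > max_errors:
                raise RuntimeError("error retry budget exhausted")
        time.sleep(1.0 + random.random())   # simple backoff with jitter
```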
comfyanonymous
66c18522fb Add a tip for common error. (#12414) 2026-02-11 22:12:16 -05:00
askmyteapot
e5ae670a40 Update ace15.py to allow min_p sampling (#12373) 2026-02-11 20:28:48 -05:00
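The actual filter lands in the ace15 diff below; as a self-contained sketch, min_p sampling keeps only the tokens whose probability is at least min_p times the probability of the most likely token:

```
import torch

def min_p_filter(logits, min_p=0.05):
    # mask tokens whose probability falls below min_p * max probability
    probs = torch.softmax(logits, dim=-1)
    p_max = probs.max(dim=-1, keepdim=True).values
    remove = probs < (min_p * p_max)
    return logits.masked_fill(remove, torch.finfo(logits.dtype).min)

logits = torch.tensor([[5.0, 4.0, 1.0, -2.0]])
filtered = min_p_filter(logits, min_p=0.1)          # keeps only the first two tokens
token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
```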
rattus
3fe61cedda model_patcher: guard against none model_dtype (#12410)
Handle the case where _model_dtype exists but is None by applying the
intended fallback.
2026-02-11 14:54:02 -05:00
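The subtlety (visible in the model_patcher hunks below, which switch to `getattr(..., None) or weight.dtype`) is that getattr's default only applies when the attribute is missing, not when it exists but is None. A tiny illustration with a stand-in class:

```
import torch

class Weight:
    _model_dtype = None          # attribute exists, but is None

w = Weight()
bad = getattr(w, "_model_dtype", torch.float16)            # -> None
good = getattr(w, "_model_dtype", None) or torch.float16   # -> torch.float16
assert bad is None and good is torch.float16
```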
rattus
2a4328d639 ace15: Use dynamic_vram friendly trange (#12409)
Factor out the ksampler trange and use it in ACE LLM to prevent the
silent stall at 0 and rate distortion due to first-step model load.
2026-02-11 14:53:42 -05:00
rattus
d297a749a2 dynamic_vram: Fix windows Aimdo crash + Fix LLM performance (#12408)
* model_management: lazy-cache aimdo_tensor

These tensors constructed from aimdo allocations are CPU-expensive to
build on the pytorch side. Add a cached version that remains valid while
the signature matches, fast-pathing past whatever torch is doing.

* dynamic_vram: Minimize fast path CPU work

Move as much as possible inside the not-resident if block and cache
the formed weight and bias rather than the flat intermediates. At
extreme per-layer weight rates this adds up.
2026-02-11 14:50:16 -05:00
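The first fix (see the model_management and ops hunks below) caches the torch tensor built from an aimdo allocation and reuses it while a signature check says the allocation is still resident. A repository-independent sketch of that pattern, with all names hypothetical:

```
# Rebuild the expensive view only when the allocation's signature has changed.
_view_cache = {}   # alloc_id -> (signature, view)

def get_view(alloc_id, current_signature, build_view):
    cached = _view_cache.get(alloc_id)
    if cached is not None and cached[0] == current_signature:
        return cached[1]              # fast path: cached view is still valid
    view = build_view()               # expensive pytorch-side construction
    _view_cache[alloc_id] = (current_signature, view)
    return view
```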
Alexander Piskun
2b7cc7e3b6 [API Nodes] enable Magnific Upscalers (#12179)
* feat(api-nodes): enable Magnific Upscalers

* update price badges

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:30:19 -08:00
Benjamin Lu
4993411fd9 Dispatch desktop auto-bump when a ComfyUI release is published (#12398)
* Dispatch desktop auto-bump on ComfyUI release publish

* Fix release webhook secret checks in step conditions

* Require desktop dispatch token in release webhook

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Luke Mino-Altherr <lminoaltherr@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:15:13 -08:00
Alexander Piskun
2c7cef4a23 fix(api-nodes): retry on connection errors during polling instead of aborting (#12393) 2026-02-11 10:51:49 -08:00
Jedrzej Kosinski
b1c69ed6f6 Improve NodeReplace docstring 2026-02-11 01:16:39 -08:00
Jedrzej Kosinski
1ef4c6e529 Merge branch 'master' into jk/node-replace-api 2026-02-11 01:06:10 -08:00
Jedrzej Kosinski
9e758b5b0c Refactored _node_replace.py InputMap/OutputMap to use a TypedDict instead of objects, simplified the schema sent to the frontend, updated nodes_post_processing.py replacements to use new schema 2026-02-11 01:02:55 -08:00
comfyanonymous
76a7fa96db Make built in lora training work on anima. (#12402) 2026-02-10 22:04:32 -05:00
Kohaku-Blueleaf
cdcf4119b3 [Trainer] training with proper offloading (#12189)
* Fix bypass dtype/device moving

* Force offloading mode for training

* training context var

* offloading implementation in training node

* fix wrong input type

* Support bypass load lora model, correct adapter/offloading handling
2026-02-10 21:45:19 -05:00
Jedrzej Kosinski
a6d691dc45 Merge branch 'master' into jk/node-replace-api 2026-02-10 16:54:35 -06:00
Christian Byrne
8d0da49499 feat: add node_replacements server feature flag (#12362)
Amp-Thread-ID: https://ampcode.com/threads/T-019c3f3d-e208-704f-bf25-4f643c1e0059
2026-02-10 14:53:28 -08:00
AustinMroz
dbe70b6821 Add a VideoSlice node (#12107)
* Base TrimVideo implementation

* Raise error if as_trimmed call fails

* Bigger max start_time, tooltips, and formatting

* Count packets unless codec has subframes

* Remove incorrect nested decode

* Add null check for audio streams

* Support non-strict duration

* Added strict_duration bool to node definition

* Empty commit for approval

* Fix duration

* Support 5.1 audio layout on save

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-10 14:42:21 -08:00
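Based on the as_trimmed signature and the VideoFromFile changes shown in the diffs below, a hedged usage sketch: the import path and file name are assumptions, VideoFromFile is assumed to implement the new abstract method (only parts of it appear in this diff), and as_trimmed returns None when the requested slice would have negative duration.

```
from comfy_api.latest import VideoFromFile    # assumed import path

video = VideoFromFile("input.mp4")            # hypothetical file
clip = video.as_trimmed(start_time=2.0, duration=5.0)   # seconds
if clip is None:
    raise ValueError("requested slice has negative duration")
print(clip.get_duration())                    # at most 5.0
```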
guill
00fff6019e feat(jobs): add 3d to PREVIEWABLE_MEDIA_TYPES for first-class 3D output support (#12381)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-10 14:37:14 -08:00
rattus
123a7874a9 ops: Fix vanilla-fp8 loaded lora quality (#12390)
This was missing the stochastic rounding required for fp8 downcast
to be consistent with model_patcher.patch_weight_to_device.

Missed in testing as I spent too much time with quantized tensors
and overlooked the simpler ones.
2026-02-10 13:38:28 -05:00
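Stochastic rounding rounds toward either neighbouring representable value with probability proportional to proximity, so the quantization error is zero in expectation instead of accumulating bias. The repo's comfy.float.stochastic_rounding applies this to the fp8 grid (it appears in the ops hunk below); here the same idea is sketched on the integer grid for clarity:

```
import torch

def stochastic_round_to_int(x, generator=None):
    # floor, then round up with probability equal to the fractional part
    low = torch.floor(x)
    frac = x - low
    round_up = (torch.rand(x.shape, generator=generator) < frac).to(x.dtype)
    return low + round_up

x = torch.full((100_000,), 0.25)
print(stochastic_round_to_int(x).mean())   # ~0.25: unbiased, unlike round() -> 0.0
```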
rattus
f719f9c062 sd: delay VAE dtype archive until after override (#12388)
VAEs have host-specific dtype logic that should override the dynamic
_model_dtype. Defer the archiving of model dtypes until after that override.
2026-02-10 13:37:46 -05:00
rattus
fe053ba5eb mp: dont deep-clone objects from model_options (#12382)
If there are non-trivial python objects nested in model_options, this
causes all sorts of issues. Traverse lists and dicts so clones can safely
override settings and BYO objects, but stop the deep clone there.
2026-02-10 13:37:17 -05:00
comfyanonymous
6648ab68bc ComfyUI v0.13.0 2026-02-10 13:26:29 -05:00
ComfyUI Wiki
6615db925c chore: update workflow templates to v0.8.38 (#12394) 2026-02-10 13:24:56 -05:00
Alexander Piskun
8ca842a8ed feat(api-nodes-Kling): add new models (V3, O3) (#12389)
* feat(api-nodes-Kling): add new models (V3, O3)

* remove storyboard from VideoToVideo node

* added check for total duration of storyboards

* fixed other small things

* updated display name for nodes

* added "fake" seed
2026-02-10 09:34:54 -08:00
Alexander Piskun
c1b63a7e78 fix(Moonvalley-API-Nodes): adjust "steps" parameter to not raise exception (#12370) 2026-02-09 21:58:27 -05:00
ComfyUI Wiki
349a636a2b chore: update workflow templates to v0.8.37 (#12377) 2026-02-09 21:25:34 -05:00
comfyanonymous
a4be04c5d7 Ace step prompts match now. (#12376) 2026-02-09 19:45:56 -05:00
blepping
baf8c87455 Improvements to ACE-Steps 1.5 text encoding (part 2) (#12350) 2026-02-09 19:41:49 -05:00
rattus
62315fbb15 Dynamic VRAM fixes - Ace 1.5 performance + a VRAM leak (#12368)
* revert threaded model loader change

This change was only needed to get around the pytorch 2.7 mempool bugs,
and should have been reverted along with #12260. This fixes a different
memory leak where pytorch gets confused about cache emptying.

* load non comfy weights

* MPDynamic: Pre-generate the tensors for vbars

Apparently this is an expensive operation that slows things down.

* bump to aimdo 1.8

New features:
watermark limit feature
logging enhancements
-O2 build on linux
2026-02-09 16:16:08 -05:00
bymyself
739ed21714 fix: use direct PromptServer registration instead of ComfyAPI class
Amp-Thread-ID: https://ampcode.com/threads/T-019c2be8-0b34-747e-b1f7-20a1a1e6c9df
2026-02-05 15:52:21 -08:00
Christian Byrne
a2d4c0f98b refactor: process isolation support for node replacement API (#12298)
* refactor: process isolation support for node replacement API

- Move REGISTERED_NODE_REPLACEMENTS global to NodeReplaceManager instance state
- Add NodeReplacement class to ComfyAPI_latest with async register() method
- Deprecate module-level register_node_replacement() function
- Call register_replacements() from comfy_entrypoint()

This enables pyisolate compatibility where extensions run in separate
processes and communicate via RPC. The async API allows registration
calls to cross process boundaries.

Refs: TDD-002
Amp-Thread-ID: https://ampcode.com/threads/T-019c2b33-ac55-76a9-9c6b-0246a8625f21

* fix: remove whitespace and deprecation cruft

Amp-Thread-ID: https://ampcode.com/threads/T-019c2be8-0b34-747e-b1f7-20a1a1e6c9df
2026-02-05 12:21:03 -08:00
Jin Yi
d5b3da823d feat: add legacy node replacements from frontend hardcoded patches (#12241) 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
8bbd8f7d65 Fix test node replacement for resize_type.multiplier field 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
d6b217a7f8 Create some test replacements for frontend testing purposes 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
04f89c75d1 Rename UseValue to SetValue 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
588bc6b257 Added old_widget_ids param to NodeReplace 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
c9dbe13c0c Add public register_node_replacement function to node_replace, add NodeReplaceManager + GET /api/node_replacements 2026-02-04 19:41:23 -08:00
Jedrzej Kosinski
7024486e37 Create helper classes for node replace registration 2026-02-04 19:41:23 -08:00
52 changed files with 2254 additions and 1128 deletions

View File

@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed; this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

app/node_replace_manager.py (new file, 105 lines)
View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._node_replace import NodeReplace
from nodes import NODE_CLASS_MAPPINGS
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if isinstance(input_value, list):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
if len(need_replacement) > 0:
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as it will only cause confusion
if new_node_id not in NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())
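Given the fields the manager reads from a registration (old_node_id, new_node_id, input_mapping entries keyed by new_id with either old_id or set_value, output_mapping entries with old_idx/new_idx), a hedged usage sketch. It uses a stand-in dataclass because the real NodeReplace lives in comfy_api.latest._node_replace and its constructor is not shown here; "OldResize" is a hypothetical deprecated node, and running this assumes a ComfyUI environment since the manager imports NODE_CLASS_MAPPINGS.

```
from dataclasses import dataclass, asdict
from app.node_replace_manager import NodeReplaceManager

@dataclass
class FakeNodeReplace:                       # stand-in with the fields used above
    old_node_id: str
    new_node_id: str
    input_mapping: list | None = None
    output_mapping: list | None = None
    def as_dict(self):
        return asdict(self)

manager = NodeReplaceManager()
manager.register(FakeNodeReplace(
    old_node_id="OldResize",                 # hypothetical removed node
    new_node_id="ImageScale",                # must exist in NODE_CLASS_MAPPINGS
    input_mapping=[
        {"new_id": "image", "old_id": "image"},                 # carry input over
        {"new_id": "upscale_method", "set_value": "bilinear"},  # hard-coded value
    ],
    output_mapping=[{"old_idx": 0, "new_idx": 0}],
))

prompt = {"3": {"class_type": "OldResize",
                "inputs": {"image": ["1", 0]},
                "_meta": {"title": "Old Resize"}}}
manager.apply_replacements(prompt)           # mutates the /prompt payload in place
assert prompt["3"]["class_type"] == "ImageScale"
```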

View File

@@ -1,12 +1,11 @@
import math
import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange as trange_, tqdm
from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step-rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

View File

@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids):
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None:
return self.llm_adapter(text_embeds, text_ids)
out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights
if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
else:
return text_embeds
def forward(self, x, timesteps, context, **kwargs):
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
if t5xxl_ids is not None:
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
return super().forward(x, timesteps, context, **kwargs)

View File

@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
apply_rope = _apply_rope
apply_rope1 = _apply_rope1

View File

@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types():
float8_types = []
try:
@@ -651,7 +656,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -747,26 +752,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained across
#nodes. Use a dummy thread to do it, as pytorch documents that mempool contexts are
#thread local, so exploit that to escape the context
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1226,21 +1211,20 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations

View File

@@ -19,7 +19,6 @@
from __future__ import annotations
import collections
import copy
import inspect
import logging
import math
@@ -317,7 +316,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -680,18 +679,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False):
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
default = True # default random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1492,9 +1492,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We have way more tools for acceleration on comfy weight offloading, so always
#We force reserve VRAM for the non comfy-weight so we don't have to deal
#with pin and unpin synchronization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@@ -1524,7 +1526,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
@@ -1550,13 +1552,14 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
@@ -1577,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)

View File

@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -169,8 +177,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
else:
y = x
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if update_weight:
orig.copy_(y)
for f in fns:
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)

View File

@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
memory_required = 1e20
minimum_memory_required = None
else:
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@@ -793,8 +793,6 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
self.device = device
@@ -803,6 +801,7 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher

View File

@@ -16,6 +16,7 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -23,6 +24,8 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
if ids is None:
return []
device = model.execution_device
if execution_dtype is None:
@@ -32,6 +35,7 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -41,22 +45,27 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in range(max_new_tokens):
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
if cfg_scale != 1.0:
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
else:
cfg_logits = next_token_logits[0:1]
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
if use_eos_score:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +73,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
if use_eos_score:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -72,6 +81,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
if min_p is not None and min_p > 0:
probs = torch.softmax(cfg_logits, dim=-1)
p_max = probs.max(dim=-1, keepdim=True).values
indices_to_remove = probs < (min_p * p_max)
cfg_logits[indices_to_remove] = remove_logit_value
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -93,8 +108,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
embeds = embed.repeat(embeds_batch, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -102,24 +117,31 @@ def sample_manual_loop_no_classes(
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
if cfg_scale != 1.0:
negative = [[token for token, _ in inner_list] for inner_list in negative]
negative = negative[0]
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +151,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
for k in ("bpm", "duration", "keyscale", "timesignature")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas["timesignature"] = timesignature[:-2]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -147,8 +169,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "duration", "keyscale", "timesignature")
use_keys = ("bpm", "timesignature", "keyscale", "duration")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -159,9 +184,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
out = {}
text = text.strip()
text_negative = kwargs.get("caption_negative", text).strip()
lyrics = kwargs.get("lyrics", "")
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
if isinstance(duration, str):
duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -170,28 +199,55 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
tokens_duration = duration * 5
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
cot_text = self._metas_to_cot(caption = text, **kwargs)
metas_negative = {
k.rsplit("_", 1)[0]: kwargs.pop(k)
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
if k in kwargs
}
if not kwargs.get("use_negative_caption"):
_ = metas_negative.pop("caption", None)
cot_text = self._metas_to_cot(caption=text, **kwargs)
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
llm_prompts = {
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
}
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
out = {
prompt_key: self.qwen3_06b.tokenize_with_weights(
prompt,
prompt_key == "qwen3_06b" and return_word_ids,
disable_weights = True,
**kwargs,
)
for prompt_key, prompt in llm_prompts.items()
}
out["lm_metadata"] = {"min_tokens": min_tokens,
"max_tokens": max_tokens,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
}
return out
@@ -252,7 +308,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes]
return base_out, None, out

View File

@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
nsin = freqs_cis[2]
q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype)

View File

@@ -27,6 +27,7 @@ from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step-rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
@@ -1376,3 +1403,21 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def deepcopy_list_dict(obj, memo=None):
if memo is None:
memo = {}
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, dict):
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
elif isinstance(obj, list):
res = [deepcopy_list_dict(i, memo) for i in obj]
else:
res = obj
memo[obj_id] = res
return res
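A quick illustration of how deepcopy_list_dict above differs from copy.deepcopy: containers are copied recursively, but leaf objects (model handles, tensors, callables) are shared instead of cloned, which is what the model_options clone needs. ModelHandle is a hypothetical stand-in, and the function above is assumed to be in scope (in the repo it lives in comfy/utils.py).

```
import copy

class ModelHandle:
    """Stand-in for a non-trivial object that must not be deep-cloned."""

opts = {"wrappers": [ModelHandle()], "sampling": {"cfg": 7.5}}

cloned = deepcopy_list_dict(opts)
assert cloned["sampling"] is not opts["sampling"]             # containers are new
assert cloned["wrappers"][0] is opts["wrappers"][0]           # leaf objects are shared
assert copy.deepcopy(opts)["wrappers"][0] is not opts["wrappers"][0]

cloned["sampling"]["cfg"] = 3.0              # safe: the original is untouched
assert opts["sampling"]["cfg"] == 7.5
```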

View File

@@ -21,6 +21,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
)
return # Already injected
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
# Move adapter weights to compute device (GPU)
# Use get_torch_device() instead of module.weight.device because
# with offloading, module weights may be on CPU while compute happens on GPU
device = comfy.model_management.get_torch_device()
# Get dtype from module weight if available
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
if device is not None:
self._move_adapter_weights_to_device(device, dtype)
# Only use dtype if it's a standard float type, not quantized
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = None
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward

View File

@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}
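How SERVER_FEATURE_FLAGS reaches the client is not part of this diff; assuming the flags dict has already been received, a client would gate the replacement UI on the new key roughly like this:

```
def supports_node_replacements(server_flags: dict) -> bool:
    # a missing key means an older server that predates the feature
    return bool(server_flags.get("node_replacements", False))

assert supports_node_replacements({"node_replacements": True})
assert not supports_node_replacements({"supports_preview_metadata": True})
```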

View File

@@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
from . import _ui_public as ui
from . import _node_replace_public as node_replace
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@@ -21,6 +22,14 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: 'node_replace.NodeReplace') -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
node_replacement: NodeReplacement
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -131,4 +140,5 @@ __all__ = [
"IO",
"ui",
"UI",
"node_replace",
]
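Combining the NodeReplacement singleton above with the process-isolation commit message, an extension would register its replacement from an async entrypoint roughly as below. The import paths, the NodeReplace constructor arguments, the API instantiation, and the comfy_entrypoint shape are assumptions inferred from the diff and commit messages, not verified API.

```
# Hedged sketch: the async call can cross a pyisolate process boundary via RPC.
from comfy_api.latest import ComfyAPI_latest, node_replace

api = ComfyAPI_latest()    # assumed instantiation of the proxied API object

async def comfy_entrypoint():
    replacement = node_replace.NodeReplace(     # constructor args assumed
        old_node_id="MyOldNode",
        new_node_id="MyNewNode",
    )
    await api.node_replacement.register(replacement)
```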

View File

@@ -34,6 +34,21 @@ class VideoInput(ABC):
"""
pass
@abstractmethod
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
Returns:
A new VideoInput, or None if the result would have negative duration
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without

View File

@@ -6,6 +6,7 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
import itertools
import json
import numpy as np
import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
self.__start_time = start_time
self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return min(self.__duration, duration_from_start)
return duration_from_start
def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for packet in frame_iterator:
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for frame in frame_iterator:
if frame.pts >= start_pts:
break
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
for frame in container.decode(video=0):
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
container.seek(start_pts, stream=video_stream)
# Use last stream for consistency
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
if len(container.streams.video):
return container.streams.video[0]
raise ValueError(f"No video stream found in file '{self.__file}'")
def as_trimmed(
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
trimmed = VideoFromFile(
self.get_stream_source(),
start_time=start_time + self.__start_time,
duration=duration,
)
if trimmed.get_duration() < duration and strict_duration:
return None
return trimmed
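For illustration only (the path is a placeholder and assumes the usual VideoFromFile constructor that accepts a file path or BytesIO), trimming is non-destructive: as_trimmed returns a new VideoFromFile that reads the same source with an offset, or None when strict_duration is requested and the remaining footage is too short:
video = VideoFromFile("input.mp4")  # placeholder path, illustrative only
clip = video.as_trimmed(start_time=2.0, duration=3.0)  # 3 s window starting at t = 2 s
if clip is None:
    # strict_duration defaulted to True and fewer than 3 s were available after the offset
    pass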
class VideoFromComponents(VideoInput):
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
frame_rate=self.__components.frame_rate,
)
def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
# TODO: Consider tracking duration and trimming at time of save?
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)

View File

@@ -1430,11 +1430,6 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
lazy_outputs: bool=False
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
def validate(self):
'''Validate the schema:
@@ -1880,14 +1875,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
_LAZY_OUTPUTS = None
@final
@classproperty
def LAZY_OUTPUTS(cls): # noqa
if cls._LAZY_OUTPUTS is None:
cls.GET_SCHEMA()
return cls._LAZY_OUTPUTS
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -1930,8 +1917,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._LAZY_OUTPUTS is None:
cls._LAZY_OUTPUTS = schema.lazy_outputs
if cls._RETURN_TYPES is None:
output = []

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Any, TypedDict
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
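A minimal sketch of how a replacement definition could be built and serialized; the node and input IDs below are hypothetical and only mirror the two InputMap forms and the registrations in nodes_replacements.py further down:
replacement = NodeReplace(
    new_node_id="NewLoaderNode",    # hypothetical class names, for illustration only
    old_node_id="OldLoaderNode",
    old_widget_ids=["model_name"],  # widget value at positional index 0 maps to input id "model_name"
    input_mapping=[
        {"new_id": "model_name", "old_id": "model_name"},  # InputMapOldId form
        {"new_id": "precision", "set_value": "fp16"},       # InputMapSetValue form
    ],
    output_mapping=[{"new_idx": 0, "old_idx": 0}],
)
payload = replacement.as_dict()  # plain, serializable dict representation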

View File

@@ -0,0 +1 @@
from ._node_replace import * # noqa: F403

View File

@@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -46,4 +46,5 @@ __all__ = [
"IO",
"ui",
"UI",
"node_replace",
]

View File

@@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
class KlingImageGenModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_5 = 'kling-v1-5'
kling_v2 = 'kling-v2'
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
model_name: str = Field(...)
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200

View File

@@ -1,12 +1,22 @@
from pydantic import BaseModel, Field
class MultiPromptEntry(BaseModel):
index: int = Field(...)
prompt: str = Field(...)
duration: str = Field(...)
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
model_name: str = Field(...)
resolution: str = Field(...)
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
result_type: str | None = Field(None, description="Set to 'series' for series generation")
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
model_name: str = Field(...)
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
model_name: str = Field(...)
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
image_tail: str | None = Field(None)
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):

File diff suppressed because it is too large

View File

@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
_EUR_TO_USD = 1.19
def _tier_price_eur(megapixels: float) -> float:
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
if megapixels <= 1.3:
return 0.143
if megapixels <= 3.0:
return 0.286
if megapixels <= 6.4:
return 0.429
return 1.716
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
num_steps = int(math.log2(scale))
total_eur = 0.0
pixels = width * height
for _ in range(num_steps):
total_eur += _tier_price_eur(pixels / 1_000_000)
pixels *= 4
return round(total_eur * _EUR_TO_USD, 2)
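A short worked example of the tier arithmetic above (the input size is arbitrary): a 1024x1024 image upscaled 4x takes log2(4) = 2 steps; the first step is billed at the <= 1.3 MP tier and the second, after the pixel count quadruples to about 4.19 MP, at the <= 6.4 MP tier.
# Worked example (input size chosen only for illustration):
#   1024 * 1024 px = 1.048576 MP      -> 0.143 EUR (<= 1.3 MP tier)
#   after one 2x step: 4.194304 MP    -> 0.429 EUR (<= 6.4 MP tier)
#   total: 0.572 EUR * 1.19 = 0.68068 -> rounds to 0.68 USD
assert _calculate_magnific_upscale_price_usd(1024, 1024, 4) == 0.68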
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr="""
(
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
$ad := widgets.auto_downscale;
$mins := $ad
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
)
""",
),
@@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
actual_scale = int(scale_factor.rstrip("x"))
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
)
""",
),
@@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
# MagnificImageUpscalerCreativeNode,
# MagnificImageUpscalerPreciseV2Node,
MagnificImageUpscalerCreativeNode,
MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,

View File

@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=33,
min=1,
default=80,
min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=33,
min=1,
default=60,
min=60, # steps should be greater than or equal to cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
steps=33,
steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=33,
min=1,
default=80,
min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",

View File

@@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -143,9 +149,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -240,9 +250,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: Please try again later."
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
should_retry = False
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
wait_time = min(rate_limit_delay, 30.0)
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
should_retry = True
if should_retry:
logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,
url,
resp.status,
delay,
attempt,
cfg.max_retries,
wait_time,
retry_label,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
wait_time,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
continue
msg = _friendly_http_message(resp.status, body)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
raise Exception(msg)
if expect_binary:
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
return bytes_payload
else:
try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
if attempt <= cfg.max_retries:
if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt,
attempt - rate_limit_attempts,
cfg.max_retries,
str(e),
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."

View File

@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue

View File

@@ -8,7 +8,6 @@ from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
@@ -91,38 +90,41 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
except Exception as _log_e:
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__':

View File

@@ -255,17 +255,14 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +308,27 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cls,

View File

@@ -5,7 +5,7 @@ import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
import nodes
@@ -115,10 +115,6 @@ class CacheKeySetInputSignature(CacheKeySet):
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
expected = get_expected_outputs_for_node(dynprompt, node_id)
signature.append(("expected_outputs", tuple(sorted(expected))))
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):

View File

@@ -19,15 +19,6 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception):
pass
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
@@ -36,7 +27,6 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
self._expected_outputs_map = None
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@@ -52,7 +42,6 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
@@ -70,26 +59,6 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map
def get_original_prompt(self):
return self.original_prompt

View File

@@ -20,10 +20,60 @@ class JobStatus:
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
def has_3d_extension(filename: str) -> bool:
lower = filename.lower()
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
def normalize_output_item(item):
"""Normalize a single output list item for the jobs API.
Returns the normalized item, or None to exclude it.
String items with 3D extensions become {filename, type, subfolder, mediaType} dicts.
"""
if item is None:
return None
if isinstance(item, str):
if has_3d_extension(item):
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
return None
if isinstance(item, dict):
return item
return None
def normalize_outputs(outputs: dict) -> dict:
"""Normalize raw node outputs for the jobs API.
Transforms string 3D filenames into file output dicts and removes
None items. All other items (non-3D strings, dicts, etc.) are
preserved as-is.
"""
normalized = {}
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
normalized[node_id] = node_outputs
continue
normalized_node = {}
for media_type, items in node_outputs.items():
if media_type == 'animated' or not isinstance(items, list):
normalized_node[media_type] = items
continue
normalized_items = []
for item in items:
if item is None:
continue
norm = normalize_output_item(item)
normalized_items.append(norm if norm is not None else item)
normalized_node[media_type] = normalized_items
normalized[node_id] = normalized_node
return normalized
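As a concrete illustration (node id and filenames are made up), a bare 3D filename is promoted to a file dict so the jobs API can preview it, while regular dict entries pass through unchanged:
# Illustrative input/output for normalize_outputs (ids and filenames are placeholders).
raw = {"7": {"3d": ["mesh.glb"],
             "images": [{"filename": "a.png", "type": "output", "subfolder": ""}]}}
out = normalize_outputs(raw)
# out["7"]["3d"][0] == {"filename": "mesh.glb", "type": "output",
#                       "subfolder": "", "mediaType": "3d"}
# out["7"]["images"][0] is unchanged.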
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', or 'audio'
1. media_type is 'images', 'video', 'audio', or '3d'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
job['outputs'] = outputs
job['outputs'] = normalize_outputs(outputs)
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
count += 1
if not isinstance(item, dict):
normalized = normalize_output_item(item)
if normalized is None:
continue
if preview_output is None and is_previewable(media_type, item):
count += 1
if preview_output is not None:
continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
enriched = {
**item,
**normalized,
'nodeId': node_id,
'mediaType': media_type
}
if item.get('type') == 'output':
if 'mediaType' not in normalized:
enriched['mediaType'] = media_type
if normalized.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched

View File

@@ -1,41 +1,23 @@
import contextvars
from typing import NamedTuple, FrozenSet
from typing import Optional, NamedTuple
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
prompt_id: The ID of the current prompt execution
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
expected_outputs: Set of output indices that might be used downstream.
Outputs NOT in this set are definitely unused (safe to skip).
None means the information is not available.
"""
prompt_id: str
node_id: str
list_index: int | None
expected_outputs: FrozenSet[int] | None = None
list_index: Optional[int]
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> ExecutionContext | None:
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)
def is_output_needed(output_index: int) -> bool:
"""Check if an output at the given index is connected downstream.
Returns True if the output might be used (should be computed).
Returns False if the output is definitely not connected (safe to skip).
"""
ctx = get_executing_context()
if ctx is None or ctx.expected_outputs is None:
return True
return output_index in ctx.expected_outputs
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
@@ -43,22 +25,15 @@ class CurrentNodeContext:
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
with CurrentNodeContext(node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(
self,
prompt_id: str,
node_id: str,
list_index: int | None = None,
expected_outputs: FrozenSet[int] | None = None,
):
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index,
expected_outputs=expected_outputs,
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
)
self.token = None

View File

@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)

View File

@@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, node_replace
from server import PromptServer
def _register(nr: node_replace.NodeReplace):
"""Helper to register replacements via PromptServer."""
PromptServer.instance.node_replace_manager.register(nr)
async def register_replacements():
"""Register all built-in node replacements."""
register_replacements_longeredge()
register_replacements_batchimages()
register_replacements_upscaleimage()
register_replacements_controlnet()
register_replacements_load3d()
register_replacements_preview3d()
register_replacements_svdimg2vid()
register_replacements_conditioningavg()
def register_replacements_longeredge():
# No dynamic inputs here
_register(node_replace.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# exists only to exercise the frontend output_mapping code; it has no functional effect here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))
def register_replacements_batchimages():
# BatchImages node uses Autogrow
_register(node_replace.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))
def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
_register(node_replace.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))
def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
_register(node_replace.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))
def register_replacements_load3d():
# Load3DAnimation merged into Load3D
_register(node_replace.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
_register(node_replace.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
_register(node_replace.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
_register(node_replace.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class NodeReplacementsExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []
async def comfy_entrypoint() -> NodeReplacementsExtension:
await register_replacements()
return NodeReplacementsExtension()

View File

@@ -4,6 +4,7 @@ import os
import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def __init__(self, *args, offloading=False, **kwargs):
super().__init__(*args, **kwargs)
self.offloading = offloading
def outer_sample(
self,
noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
force_full_load=not self.offloading,
force_offload=self.offloading,
)
)
torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
return result
def patch(m):
def find_modules_at_depth(
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
"""
Find modules at a specific depth level for gradient checkpointing.
Args:
model: The model to search
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
result: Accumulator for results
current_depth: Current recursion depth
name: Current module name for logging
Returns:
List of modules at the target depth
"""
if result is None:
result = []
name = name or "root"
# Skip container modules (they don't have meaningful forward)
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
has_forward = hasattr(model, "forward") and not is_container
if has_forward:
current_depth += 1
if current_depth == depth:
result.append(model)
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
return result
# Recurse into children
for next_name, child in model.named_children():
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
return result
class OffloadCheckpointFunction(torch.autograd.Function):
"""
Gradient checkpointing that works with weight offloading.
Forward: no_grad -> compute -> weights can be freed
Backward: enable_grad -> recompute -> backward -> weights can be freed
For single input, single output modules (Linear, Conv*).
"""
@staticmethod
def forward(ctx, x: torch.Tensor, forward_fn):
ctx.save_for_backward(x)
ctx.forward_fn = forward_fn
with torch.no_grad():
return forward_fn(x)
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
x, = ctx.saved_tensors
forward_fn = ctx.forward_fn
# Clear context early
ctx.forward_fn = None
with torch.enable_grad():
x_detached = x.detach().requires_grad_(True)
y = forward_fn(x_detached)
y.backward(grad_out)
grad_x = x_detached.grad
# Explicit cleanup
del y, x_detached, forward_fn
return grad_x, None
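A minimal sketch of how the checkpoint function is exercised on a single layer (layer and shapes are arbitrary): the forward pass runs under no_grad so no activation graph is kept, and the backward pass recomputes the layer under enable_grad before propagating the incoming gradient:
# Illustrative only; any single-input/single-output module works the same way.
lin = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = OffloadCheckpointFunction.apply(x, lin.forward)  # no activation graph stored here
y.sum().backward()  # lin.forward is recomputed, then grads reach lin's weights and x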
def patch(m, offloading=False):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
def checkpointing_fwd(x):
return OffloadCheckpointFunction.apply(x, org_forward)
# Branch 2: Others -> standard checkpoint
else:
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Int.Input(
"checkpoint_depth",
default=1,
min=1,
max=5,
tooltip="Depth level for gradient checkpointing.",
),
io.Boolean.Input(
"offloading",
default=False,
tooltip="Depth level for gradient checkpointing.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
checkpoint_depth,
offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
modules_to_patch = find_modules_at_depth(
mp.model.diffusion_model, depth=checkpoint_depth
)
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
for m in modules_to_patch:
patch(m, offloading=offloading)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
[mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
@@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
guider = TrainGuider(mp)
guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
del adapter
del all_weight_adapters
# mp in the train node is highly specialized for training;
# using it for inference would result in bad behavior, so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):#
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
],
outputs=[
io.Model.Output(
@@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
)
@classmethod
def execute(cls, model, lora, strength_model):
def execute(cls, model, lora, strength_model, bypass=False):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
if bypass:
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
model, None, lora, strength_model, 0
)
else:
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)

View File

@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
return True
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="image/video",
inputs=[
io.Video.Input("video"),
io.Float.Input(
"start_time",
default=0.0,
max=1e5,
min=-1e5,
step=0.001,
tooltip="Start time in seconds",
),
io.Float.Input(
"duration",
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds, or 0 for unlimited duration",
),
io.Boolean.Input(
"strict_duration",
default=False,
tooltip="If True, when the specified duration is not possible, an error will be raised.",
),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
)
class VideoExtension(ComfyExtension):
@override
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.12.3"
__version__ = "0.13.0"

View File

@@ -13,8 +13,11 @@ from contextlib import nullcontext
import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -31,7 +34,6 @@ from comfy_execution.graph import (
ExecutionBlocker,
ExecutionList,
get_input_info,
get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -228,18 +230,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(
prompt_id,
unique_id,
obj,
input_data_all,
func,
allow_interrupt=False,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -289,12 +280,10 @@ async def _async_map_node_over_list(
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
return await f(**args)
task = asyncio.create_task(
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
@@ -303,7 +292,7 @@ async def _async_map_node_over_list(
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
with CurrentNodeContext(prompt_id, unique_id, index):
result = f(**inputs)
results.append(result)
else:
@@ -341,17 +330,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(
prompt_id,
unique_id,
obj,
input_data_all,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -545,14 +525,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
torch.cuda.synchronize()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
@@ -642,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node, make sure the correct file(s) and type are selected."
error_details = {
"node_id": real_node_id,

View File

@@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.12.3"
version = "0.13.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.31
comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.7
comfy-aimdo>=0.1.8
requests
#non essential dependencies:

View File

@@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
self.node_replace_manager.apply_replacements(prompt)
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
@@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.

View File

@@ -1,322 +0,0 @@
"""Unit tests for the expected_outputs feature.
This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""
from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
CurrentNodeContext,
ExecutionContext,
get_executing_context,
is_output_needed,
)
class TestGetExpectedOutputsForNode:
"""Tests for get_expected_outputs_for_node() function."""
def test_single_output_connected(self):
"""Test node with single output connected to one downstream node."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_multiple_outputs_partial_connected(self):
"""Test node with multiple outputs, only some connected."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
# Output 1 is not connected
"3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 2})
assert 1 not in expected # Output 1 is definitely unused
def test_no_outputs_connected(self):
"""Test node with no outputs connected."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "OtherNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset()
def test_same_output_connected_multiple_times(self):
"""Test same output connected to multiple downstream nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
"4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_node_not_in_prompt(self):
"""Test getting expected outputs for a node not in the prompt."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "999")
assert expected == frozenset()
def test_chained_nodes(self):
"""Test expected outputs in a chain of nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1's output 0 is connected to node 2
expected_1 = get_expected_outputs_for_node(dynprompt, "1")
assert expected_1 == frozenset({0})
# Node 2's output 0 is connected to node 3
expected_2 = get_expected_outputs_for_node(dynprompt, "2")
assert expected_2 == frozenset({0})
# Node 3 has no downstream connections
expected_3 = get_expected_outputs_for_node(dynprompt, "3")
assert expected_3 == frozenset()
def test_complex_graph(self):
"""Test expected outputs in a complex graph with multiple connections."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
"3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
"4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1 has outputs 0, 1, 2 all connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1, 2})
def test_constant_inputs_ignored(self):
"""Test that constant (non-link) inputs don't affect expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {
"class_type": "ConsumerNode",
"inputs": {
"image": ["1", 0],
"value": 42,
"name": "test",
},
},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_ephemeral_node_invalidates_cache(self):
"""Test that adding ephemeral nodes updates expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Initially only output 0 is connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
# Add an ephemeral node that connects to output 1
dynprompt.add_ephemeral_node(
"eph_1",
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
parent_id="2",
display_id="2",
)
# Now both outputs 0 and 1 should be expected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1})
class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field."""
def test_context_with_expected_outputs(self):
"""Test creating ExecutionContext with expected_outputs."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
)
assert ctx.prompt_id == "prompt-123"
assert ctx.node_id == "node-456"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 2})
def test_context_without_expected_outputs(self):
"""Test ExecutionContext defaults to None for expected_outputs."""
ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
assert ctx.expected_outputs is None
def test_context_empty_expected_outputs(self):
"""Test ExecutionContext with empty expected_outputs set."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
)
assert ctx.expected_outputs == frozenset()
assert len(ctx.expected_outputs) == 0
class TestCurrentNodeContext:
"""Tests for CurrentNodeContext context manager with expected_outputs."""
def test_context_manager_with_expected_outputs(self):
"""Test CurrentNodeContext sets and resets context correctly."""
assert get_executing_context() is None
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
ctx = get_executing_context()
assert ctx is not None
assert ctx.prompt_id == "prompt-1"
assert ctx.node_id == "node-1"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 1})
assert get_executing_context() is None
def test_context_manager_without_expected_outputs(self):
"""Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
with CurrentNodeContext("prompt-1", "node-1"):
ctx = get_executing_context()
assert ctx is not None
assert ctx.expected_outputs is None
def test_nested_context_managers(self):
"""Test nested CurrentNodeContext managers."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
ctx1 = get_executing_context()
assert ctx1.expected_outputs == frozenset({0})
with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
ctx2 = get_executing_context()
assert ctx2.expected_outputs == frozenset({1, 2})
assert ctx2.node_id == "node-2"
# After inner context exits, should be back to outer context
ctx1_again = get_executing_context()
assert ctx1_again.expected_outputs == frozenset({0})
assert ctx1_again.node_id == "node-1"
def test_output_check_pattern(self):
"""Test the typical pattern nodes will use to check expected outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
ctx = get_executing_context()
# Typical usage pattern
if ctx and ctx.expected_outputs is not None:
should_compute_0 = 0 in ctx.expected_outputs
should_compute_1 = 1 in ctx.expected_outputs
should_compute_2 = 2 in ctx.expected_outputs
else:
# Fallback when info not available
should_compute_0 = should_compute_1 = should_compute_2 = True
assert should_compute_0 is True
assert should_compute_1 is False # Not in expected_outputs
assert should_compute_2 is True
class TestSchemaLazyOutputs:
"""Tests for lazy_outputs in V3 Schema."""
def test_schema_lazy_outputs_default(self):
"""Test that lazy_outputs defaults to False."""
schema = IO.Schema(
node_id="TestNode",
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is False
def test_schema_lazy_outputs_true(self):
"""Test setting lazy_outputs to True."""
schema = IO.Schema(
node_id="TestNode",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is True
def test_v3_node_lazy_outputs_property(self):
"""Test that LAZY_OUTPUTS property works on V3 nodes."""
class TestNodeWithLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithLazyOutputs",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True
def test_v3_node_lazy_outputs_default(self):
"""Test that LAZY_OUTPUTS defaults to False on V3 nodes."""
class TestNodeWithoutLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithoutLazyOutputs",
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
class TestIsOutputNeeded:
"""Tests for is_output_needed() helper function."""
def test_output_needed_when_in_expected(self):
"""Test that output is needed when in expected_outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
assert is_output_needed(0) is True
assert is_output_needed(2) is True
def test_output_not_needed_when_not_in_expected(self):
"""Test that output is not needed when not in expected_outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
assert is_output_needed(1) is False
assert is_output_needed(3) is False
def test_output_needed_when_no_context(self):
"""Test that output is needed when no context."""
assert get_executing_context() is None
assert is_output_needed(0) is True
assert is_output_needed(1) is True
def test_output_needed_when_expected_outputs_is_none(self):
"""Test that output is needed when expected_outputs is None."""
with CurrentNodeContext("prompt-1", "node-1", 0, None):
assert is_output_needed(0) is True
assert is_output_needed(1) is True

View File

@@ -574,104 +574,6 @@ class TestExecution:
else:
assert result.did_run(test_node), "The execution should have been re-run"
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs contains all connected outputs."""
g = builder
# Create a node with 3 outputs, all connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Connect all 3 outputs to preview nodes
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# All outputs should be white (255) since all are connected
images0 = result.get_images(output0)
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images1) == 1, "Should have 1 image for output1"
assert len(images2) == 1, "Should have 1 image for output2"
# White pixels = 255, meaning output was in expected_outputs
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs only contains connected outputs."""
g = builder
# Create a node with 3 outputs, only some connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect outputs 0 and 2, leave output 1 disconnected
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
# output1 is intentionally not connected
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# Connected outputs should be white (255)
images0 = result.get_images(output0)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images2) == 1, "Should have 1 image for output2"
# White = expected, output 1 is not connected so we can't verify it directly but outputs 0 and 2 should be white
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs works with single connected output."""
g = builder
# Create a node with 3 outputs, only one connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect output 1
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
result = client.run(g)
images1 = result.get_images(output1)
assert len(images1) == 1, "Should have 1 image for output1"
# Output 1 should be white (connected), others are not visible in this test
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
"""Test that cache invalidates when output connections change."""
g = builder
# Use unique dimensions to avoid cache collision with other expected_outputs tests
expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)
# First run: only connect output 0
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
result1 = client.run(g)
assert result1.did_run(expected_outputs_node), "First run should execute the node"
# Second run: same connections, should be cached
result2 = client.run(g)
if server["should_cache_results"]:
assert not result2.did_run(expected_outputs_node), "Second run should be cached"
# Third run: add connection to output 2
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result3 = client.run(g)
# Because LAZY_OUTPUTS=True, changing connections should invalidate cache
if server["should_cache_results"]:
assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"
# Verify both outputs are now white
images0 = result3.get_images(output0)
images2 = result3.get_images(output2)
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized

View File

@@ -5,8 +5,11 @@ from comfy_execution.jobs import (
is_previewable,
normalize_queue_item,
normalize_history_item,
normalize_output_item,
normalize_outputs,
get_outputs_summary,
apply_sorting,
has_3d_extension,
)
@@ -35,8 +38,8 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
"""Images, video, audio, 3d media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
@@ -46,7 +49,7 @@ class TestIsPreviewable:
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
@@ -160,7 +163,7 @@ class TestGetOutputsSummary:
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -192,6 +195,64 @@ class TestGetOutputsSummary:
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
def test_string_3d_filename_creates_preview(self):
"""String items with 3D extensions should synthesize a preview (Preview3D node output).
Only the .glb counts — nulls and non-file strings are excluded."""
outputs = {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
assert preview is not None
assert preview['filename'] == 'preview3d_abc123.glb'
assert preview['mediaType'] == '3d'
assert preview['nodeId'] == 'node1'
assert preview['type'] == 'output'
def test_string_non_3d_filename_no_preview(self):
"""String items without 3D extensions should not create a preview."""
outputs = {
'node1': {
'result': ['data.json', None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 0
assert preview is None
def test_string_3d_filename_used_as_fallback(self):
"""String 3D preview should be used when no dict items are previewable."""
outputs = {
'node1': {
'latents': [{'filename': 'latent.safetensors'}],
},
'node2': {
'result': ['model.glb', None]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None
assert preview['filename'] == 'model.glb'
assert preview['mediaType'] == '3d'
class TestHas3DExtension:
"""Unit tests for has_3d_extension()"""
def test_recognized_extensions(self):
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
assert has_3d_extension(f'model{ext}') is True
def test_case_insensitive(self):
assert has_3d_extension('MODEL.GLB') is True
assert has_3d_extension('Scene.GLTF') is True
def test_non_3d_extensions(self):
for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
assert has_3d_extension(name) is False
class TestApplySorting:
"""Unit tests for apply_sorting()"""
@@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}
def test_include_outputs_normalizes_3d_strings(self):
"""Detail view should transform string 3D filenames into file output dicts."""
history_item = {
'prompt': (
5,
'prompt-3d',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
},
}
job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
assert job['outputs_count'] == 1
result_items = job['outputs']['node1']['result']
assert len(result_items) == 1
assert result_items[0] == {
'filename': 'preview3d_abc123.glb',
'type': 'output',
'subfolder': '',
'mediaType': '3d',
}
def test_include_outputs_preserves_dict_items(self):
"""Detail view normalization should pass dict items through unchanged."""
history_item = {
'prompt': (
5,
'prompt-img',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'images': [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
}
},
}
job = normalize_history_item('prompt-img', history_item, include_outputs=True)
assert job['outputs_count'] == 1
assert job['outputs']['node1']['images'] == [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
class TestNormalizeOutputItem:
"""Unit tests for normalize_output_item()"""
def test_none_returns_none(self):
assert normalize_output_item(None) is None
def test_string_3d_extension_synthesizes_dict(self):
result = normalize_output_item('model.glb')
assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
def test_string_non_3d_extension_returns_none(self):
assert normalize_output_item('data.json') is None
def test_string_no_extension_returns_none(self):
assert normalize_output_item('camera_info_string') is None
def test_dict_passes_through(self):
item = {'filename': 'test.png', 'type': 'output'}
assert normalize_output_item(item) is item
def test_other_types_return_none(self):
assert normalize_output_item(42) is None
assert normalize_output_item(True) is None
class TestNormalizeOutputs:
"""Unit tests for normalize_outputs()"""
def test_empty_outputs(self):
assert normalize_outputs({}) == {}
def test_dict_items_pass_through(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
}
}
result = normalize_outputs(outputs)
assert result == outputs
def test_3d_string_synthesized(self):
outputs = {
'node1': {
'result': ['model.glb', None, None],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': [
{'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
],
}
}
def test_animated_key_preserved(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
'animated': [True],
}
}
result = normalize_outputs(outputs)
assert result['node1']['animated'] == [True]
def test_non_dict_node_outputs_preserved(self):
outputs = {'node1': 'unexpected_value'}
result = normalize_outputs(outputs)
assert result == {'node1': 'unexpected_value'}
def test_none_items_filtered_but_other_types_preserved(self):
outputs = {
'node1': {
'result': ['data.json', None, [1, 2, 3]],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': ['data.json', [1, 2, 3]],
}
}

View File

@@ -6,7 +6,6 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context
class TestLazyMixImages:
@classmethod
@@ -483,57 +482,6 @@ class TestOutputNodeWithSocketOutput:
result = image * value
return (result,)
class TestExpectedOutputs:
"""Test node for the expected_outputs feature.
This node has 3 IMAGE outputs that encode which outputs were expected:
- White image (255) if the output was in expected_outputs
- Black image (0) if the output was NOT in expected_outputs
This allows integration tests to verify which outputs were expected by checking pixel values.
"""
LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("output0", "output1", "output2")
FUNCTION = "execute"
CATEGORY = "_for_testing"
def execute(self, height, width):
ctx = get_executing_context()
# Default: assume all outputs are expected (backwards compatibility)
output0_expected = True
output1_expected = True
output2_expected = True
if ctx is not None and ctx.expected_outputs is not None:
output0_expected = 0 in ctx.expected_outputs
output1_expected = 1 in ctx.expected_outputs
output2_expected = 2 in ctx.expected_outputs
# Return white image if expected, black if not
# This allows tests to verify which outputs were expected via pixel values
white = torch.ones(1, height, width, 3)
black = torch.zeros(1, height, width, 3)
return (
white if output0_expected else black,
white if output1_expected else black,
white if output2_expected else black,
)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -550,7 +498,6 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
"TestExpectedOutputs": TestExpectedOutputs,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -569,5 +516,4 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
"TestExpectedOutputs": "Test Expected Outputs",
}