Compare commits

..

83 Commits

Author SHA1 Message Date
comfyanonymous
265adad858 ComfyUI version v0.3.68 2025-11-04 19:42:23 -05:00
comfyanonymous
7f3e4d486c Limit amount of pinned memory on windows to prevent issues. (#10638) 2025-11-04 17:37:50 -05:00
rattus
a389ee01bb caching: Handle None outputs tuple case (#10637) 2025-11-04 14:14:10 -08:00
ComfyUI Wiki
9c71a66790 chore: update workflow templates to v0.2.11 (#10634) 2025-11-04 10:51:53 -08:00
comfyanonymous
af4b7b5edb More fp8 torch.compile regressions fixed. (#10625) 2025-11-03 22:14:20 -05:00
comfyanonymous
0f4ef3afa0 This seems to slow things down slightly on Linux. (#10624) 2025-11-03 21:47:14 -05:00
comfyanonymous
6b88478f9f Bring back fp8 torch compile performance to what it should be. (#10622) 2025-11-03 19:22:10 -05:00
comfyanonymous
e199c8cc67 Fixes (#10621) 2025-11-03 17:58:24 -05:00
comfyanonymous
0652cb8e2d Speed up torch.compile (#10620) 2025-11-03 17:37:12 -05:00
comfyanonymous
958a17199a People should update their pytorch versions. (#10618) 2025-11-03 17:08:30 -05:00
ComfyUI Wiki
e974e554ca chore: update embedded docs to v0.3.1 (#10614) 2025-11-03 10:59:44 -08:00
Alexander Piskun
4e2110c794 feat(Pika-API-nodes): use new API client (#10608) 2025-11-03 00:29:08 -08:00
Alexander Piskun
e617cddf24 convert nodes_openai.py to V3 schema (#10604) 2025-11-03 00:28:13 -08:00
Alexander Piskun
1f3f7a2823 convert nodes_hypernetwork.py to V3 schema (#10583) 2025-11-03 00:21:47 -08:00
EverNebula
88df172790 fix(caching): treat bytes as hashable (#10567) 2025-11-03 00:16:40 -08:00
Alexander Piskun
6d6a18b0b7 fix(api-nodes-cloud): stop using sub-folder and absolute path for output of Rodin3D nodes (#10556) 2025-11-03 00:04:56 -08:00
comfyanonymous
97ff9fae7e Clarify help text for --fast argument (#10609)
Updated help text for the --fast argument to clarify potential risks.
2025-11-02 13:14:04 -05:00
rattus
135fa49ec2 Small speed improvements to --async-offload (#10593)
* ops: dont take an offload stream if you dont need one

* ops: prioritize mem transfer

The async offload streams reason for existence is to transfer from
RAM to GPU. The post processing compute steps are a bonus on the side
stream, but if the compute stream is running a long kernel, it can
stall the side stream, as it wait to type-cast the bias before
transferring the weight. So do a pure xfer of the weight straight up,
then do everything bias, then go back to fix the weight type and do
weight patches.
2025-11-01 18:48:53 -04:00
comfyanonymous
44869ff786 Fix issue with pinned memory. (#10597) 2025-11-01 17:25:59 -04:00
Alexander Piskun
20182a393f convert StabilityAI to use new API client (#10582) 2025-11-01 12:14:06 -07:00
Alexander Piskun
5f109fe6a0 added 12s-20s as available output durations for the LTXV API nodes (#10570) 2025-11-01 12:13:39 -07:00
comfyanonymous
c58c13b2ba Fix torch compile regression on fp8 ops. (#10580) 2025-11-01 00:25:17 -04:00
comfyanonymous
7f374e42c8 ScaleROPE now works on Lumina models. (#10578) 2025-10-31 15:41:40 -04:00
comfyanonymous
27d1bd8829 Fix rope scaling. (#10560) 2025-10-30 22:51:58 -04:00
comfyanonymous
614cf9805e Add a ScaleROPE node. Currently only works on WAN models. (#10559) 2025-10-30 22:11:38 -04:00
rattus
513b0c46fb Add RAM Pressure cache mode (#10454)
* execution: Roll the UI cache into the outputs

Currently the UI cache is parallel to the output cache with
expectations of being a content superset of the output cache.
At the same time the UI and output cache are maintained completely
seperately, making it awkward to free the output cache content without
changing the behaviour of the UI cache.

There are two actual users (getters) of the UI cache. The first is
the case of a direct content hit on the output cache when executing a
node. This case is very naturally handled by merging the UI and outputs
cache.

The second case is the history JSON generation at the end of the prompt.
This currently works by asking the cache for all_node_ids and then
pulling the cache contents for those nodes. all_node_ids is the nodes
of the dynamic prompt.

So fold the UI cache into the output cache. The current UI cache setter
now writes to a prompt-scope dict. When the output cache is set, just
get this value from the dict and tuple up with the outputs.

When generating the history, simply iterate prompt-scope dict.

This prepares support for more complex caching strategies (like RAM
pressure caching) where less than 1 workflow will be cached and it
will be desirable to keep the UI cache and output cache in sync.

* sd: Implement RAM getter for VAE

* model_patcher: Implement RAM getter for ModelPatcher

* sd: Implement RAM getter for CLIP

* Implement RAM Pressure cache

Implement a cache sensitive to RAM pressure. When RAM headroom drops
down below a certain threshold, evict RAM-expensive nodes from the
cache.

Models and tensors are measured directly for RAM usage. An OOM score
is then computed based on the RAM usage of the node.

Note the due to indirection through shared objects (like a model
patcher), multiple nodes can account the same RAM as their individual
usage. The intent is this will free chains of nodes particularly
model loaders and associate loras as they all score similar and are
sorted in close to each other.

Has a bias towards unloading model nodes mid flow while being able
to keep results like text encodings and VAE.

* execution: Convert the cache entry to NamedTuple

As commented in review.

Convert this to a named tuple and abstract away the tuple type
completely from graph.py.
2025-10-30 17:39:02 -04:00
Alexander Piskun
dfac94695b fix img2img operation in Dall2 node (#10552) 2025-10-30 10:22:35 -07:00
Alexander Piskun
163b629c70 use new API client in Pixverse and Ideogram nodes (#10543) 2025-10-29 23:49:03 -07:00
Jedrzej Kosinski
998bf60beb Add units/info for the numbers displayed on 'load completely' and 'load partially' log messages (#10538) 2025-10-29 19:37:06 -04:00
comfyanonymous
906c089957 Fix small performance regression with fp8 fast and scaled fp8. (#10537) 2025-10-29 19:29:01 -04:00
comfyanonymous
25de7b1bfa Try to fix slow load issue on low ram hardware with pinned mem. (#10536) 2025-10-29 17:20:27 -04:00
rattus
ab7ab5be23 Fix Race condition in --async-offload that can cause corruption (#10501)
* mm: factor out the current stream getter

Make this a reusable function.

* ops: sync the offload stream with the consumption of w&b

This sync is nessacary as pytorch will queue cuda async frees on the
same stream as created to tensor. In the case of async offload, this
will be on the offload stream.

Weights and biases can go out of scope in python which then
triggers the pytorch garbage collector to queue the free operation on
the offload stream possible before the compute stream has used the
weight. This causes a use after free on weight data leading to total
corruption of some workflows.

So sync the offload stream with the compute stream after the weight
has been used so the free has to wait for the weight to be used.

The cast_bias_weight is extended in a backwards compatible way with
the new behaviour opt-in on a defaulted parameter. This handles
custom node packs calling cast_bias_weight and defeatures
async-offload for them (as they do not handle the race).

The pattern is now:

cast_bias_weight(... , offloadable=True) #This might be offloaded
thing(weight, bias, ...)
uncast_bias_weight(...)

* controlnet: adopt new cast_bias_weight synchronization scheme

This is nessacary for safe async weight offloading.

* mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesnt allow a lot of
parallelization, as then end result is you can only get one weight load
ahead regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of the
compute streams position.
2025-10-29 17:17:46 -04:00
comfyanonymous
ec4fc2a09a Fix case of weights not being unpinned. (#10533) 2025-10-29 15:48:06 -04:00
comfyanonymous
1a58087ac2 Reduce memory usage for fp8 scaled op. (#10531) 2025-10-29 15:43:51 -04:00
Alexander Piskun
6c14f3afac use new API client in Luma and Minimax nodes (#10528) 2025-10-29 11:14:56 -07:00
comfyanonymous
e525673f72 Fix issue. (#10527) 2025-10-29 00:37:00 -04:00
comfyanonymous
3fa7a5c04a Speed up offloading using pinned memory. (#10526)
To enable this feature use: --fast pinned_memory
2025-10-29 00:21:01 -04:00
Alexander Piskun
210f7a1ba5 convert nodes_recraft.py to V3 schema (#10507) 2025-10-28 14:38:05 -07:00
rattus
d202c2ba74 execution: Allow a subgraph nodes to execute multiple times (#10499)
In the case of --cache-none lazy and subgraph execution can cause
anything to be run multiple times per workflow. If that rerun nodes is
in itself a subgraph generator, this will crash for two reasons.

pending_subgraph_results[] does not cleanup entries after their use.
So when a pending_subgraph_result is consumed, remove it from the list
so that if the corresponding node is fully re-executed this misses
lookup and it fall through to execute the node as it should.

Secondly, theres is an explicit enforcement against dups in the
addition of subgraphs nodes as ephemerals to the dymprompt. Remove this
enforcement as the use case is now valid.
2025-10-28 16:22:08 -04:00
contentis
8817f8fc14 Mixed Precision Quantization System (#10498)
* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Fix missing keys

* Rename quant dtype parameter

* Rename quant dtype parameter

* Fix unittests for CPU build
2025-10-28 16:20:53 -04:00
comfyanonymous
22e40d2ace Tell users to update their nvidia drivers if portable doesn't start. (#10518) 2025-10-28 15:08:08 -04:00
comfyanonymous
3bea4efc6b Tell users to update nvidia drivers if problem with portable. (#10510) 2025-10-28 04:45:45 -04:00
comfyanonymous
8cf2ba4ba6 Remove comfy api key from queue api. (#10502) 2025-10-28 03:23:52 -04:00
comfyanonymous
b61a40cbc9 Bump stable portable to cu130 python 3.13.9 (#10508) 2025-10-28 03:21:45 -04:00
comfyanonymous
f2bb3230b7 ComfyUI version v0.3.67 2025-10-28 03:03:59 -04:00
Jedrzej Kosinski
614b8d3345 frontend bump to 1.28.8 (#10506) 2025-10-28 03:01:13 -04:00
ComfyUI Wiki
6abc30aae9 Update template to 0.2.4 (#10505) 2025-10-28 01:56:30 -04:00
Alexander Piskun
55bad30375 feat(api-nodes): add LTXV API nodes (#10496) 2025-10-27 22:25:29 -07:00
ComfyUI Wiki
c305deed56 Update template to 0.2.3 (#10503) 2025-10-27 22:24:16 -07:00
comfyanonymous
601ee1775a Add a bat to run comfyui portable without api nodes. (#10504) 2025-10-27 23:54:00 -04:00
comfyanonymous
c170fd2db5 Bump portable deps workflow to torch cu130 python 3.13.9 (#10493) 2025-10-26 20:23:01 -04:00
Alexander Piskun
9d529e5308 fix(api-nodes): random issues on Windows by capturing general OSError for retries (#10486) 2025-10-25 23:51:06 -07:00
comfyanonymous
f6bbc1ac84 Fix mistake. (#10484) 2025-10-25 23:07:29 -04:00
comfyanonymous
098a352f13 Add warning for torch-directml usage (#10482)
Added a warning message about the state of torch-directml.
2025-10-25 20:05:22 -04:00
Alexander Piskun
e86b79ab9e convert Gemini API nodes to V3 schema (#10476) 2025-10-25 14:35:30 -07:00
comfyanonymous
426cde37f1 Remove useless function (#10472) 2025-10-24 19:56:51 -04:00
Alexander Piskun
dd5af0c587 convert Tripo API nodes to V3 schema (#10469) 2025-10-24 15:48:34 -07:00
Alexander Piskun
388b306a2b feat(api-nodes): network client v2: async ops, cancellation, downloads, refactor (#10390)
* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* converted WAN nodes to use new client; polishing

* fix(auth): do not leak authentification for the absolute urls

* convert BFL API nodes to use new API client; remove deprecated BFL nodes

* converted Google Veo nodes

* fix(Veo3.1 model): take into account "generate_audio" parameter
2025-10-23 22:37:16 -07:00
ComfyUI Wiki
24188b3141 Update template to 0.2.2 (#10461)
Fix template typo issue
2025-10-24 01:36:30 -04:00
comfyanonymous
1bcda6df98 WIP way to support multi multi dimensional latents. (#10456) 2025-10-23 21:21:14 -04:00
comfyanonymous
a1864c01f2 Small readme improvement. (#10442) 2025-10-22 17:26:22 -04:00
rattus
4739d7717f execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (Resubmit) (#10440)
* execution: fold in dependency aware caching

This makes --cache-none compatiable with lazy and expanded
subgraphs.

Currently the --cache-none option is powered by the
DependencyAwareCache. The cache attempts to maintain a parallel
copy of the execution list data structure, however it is only
setup once at the start of execution and does not get meaninigful
updates to the execution list.

This causes multiple problems when --cache-none is used with lazy
and expanded subgraphs as the DAC does not accurately update its
copy of the execution data structure.

DAC has an attempt to handle subgraphs ensure_subcache however
this does not accurately connect to nodes outside the subgraph.
The current semantics of DAC are to free a node ASAP after the
dependent nodes are executed.

This means that if a subgraph refs such a node it will be requed
and re-executed by the execution_list but DAC wont see it in
its to-free lists anymore and leak memory.

Rather than try and cover all the cases where the execution list
changes from inside the cache, move the while problem to the
executor which maintains an always up-to-date copy of the wanted
data-structure.

The executor now has a fast-moving run-local cache of its own.
Each _to node has its own mini cache, and the cache is unconditionally
primed at the time of add_strong_link.

add_strong_link is called for all of static workflows, lazy links
and expanded subgraphs so its the singular source of truth for
output dependendencies.

In the case of a cache-hit, the executor cache will hold the non-none
value (it will respect updates if they happen somehow as well).

In the case of a cache-miss, the executor caches a None and will
wait for a notification to update the value when the node completes.

When a node completes execution, it simply releases its mini-cache
and in turn its strong refs on its direct anscestor outputs, allowing
for ASAP freeing (same as the DependencyAwareCache but a little more
automatic).

This now allows for re-implementation of --cache-none with no cache
at all. The dependency aware cache was also observing the dependency
sematics for the objects and UI cache which is not accurate (this
entire logic was always outputs specific).

This also prepares for more complex caching strategies (such as RAM
pressure based caching), where a cache can implement any freeing
strategy completely independently of the DepedancyAwareness
requirement.

* main: re-implement --cache-none as no cache at all

The execution list now tracks the dependency aware caching more
correctly that the DependancyAwareCache.

Change it to a cache that does nothing.

* test_execution: add --cache-none to the test suite

--cache-none is now expected to work universally. Run it through the
full unit test suite. Propagate the server parameterization for whether
or not the server is capabale of caching, so that the minority of tests
that specifically check for cache hits can if else. Hard assert NOT
caching in the else to give some coverage of --cache-none expected
behaviour to not acutally cache.
2025-10-22 15:49:05 -04:00
Jedrzej Kosinski
f13cff0be6 Add custom node published subgraphs endpoint (#10438)
* Add get_subgraphs_dir to ComfyExtension and PUBLISHED_SUBGRAPH_DIRS to nodes.py

* Created initial endpoints, although the returned paths are a bit off currently

* Fix path and actually return real data

* Sanitize returned /api/global_subgraphs entries

* Remove leftover function from early prototyping

* Remove added whitespace

* Add None check for sanitize_entry
2025-10-21 23:16:16 -04:00
comfyanonymous
9cdc64998f Only disable cudnn on newer AMD GPUs. (#10437) 2025-10-21 19:15:23 -04:00
comfyanonymous
560b1bdfca ComfyUI version v0.3.66 2025-10-21 01:12:32 -04:00
comfyanonymous
b7992f871a Revert "execution: fold in dependency aware caching / Fix --cache-none with l…" (#10422)
This reverts commit b1467da480.
2025-10-20 19:03:06 -04:00
comfyanonymous
2c2aa409b0 Log message for cudnn disable on AMD. (#10418) 2025-10-20 15:43:24 -04:00
ComfyUI Wiki
a4787ac83b Update template to 0.2.1 (#10413)
* Update template to 0.1.97

* Update template to 0.2.1
2025-10-20 15:28:36 -04:00
Christian Byrne
b5c59b763c Deprecation warning on unused files (#10387)
* only warn for unused files

* include internal extensions
2025-10-19 13:05:46 -07:00
comfyanonymous
b4f30bd408 Pytorch is stupid. (#10398) 2025-10-19 01:25:35 -04:00
comfyanonymous
dad076aee6 Speed up chroma radiance. (#10395) 2025-10-18 23:19:52 -04:00
comfyanonymous
0cf33953a7 Fix batch size above 1 giving bad output in chroma radiance. (#10394) 2025-10-18 23:15:34 -04:00
comfyanonymous
5b80addafd Turn off cuda malloc by default when --fast autotune is turned on. (#10393) 2025-10-18 22:35:46 -04:00
comfyanonymous
9da397ea2f Disable torch compiler for cast_bias_weight function (#10384)
* Disable torch compiler for cast_bias_weight function

* Fix torch compile.
2025-10-17 20:03:28 -04:00
comfyanonymous
92d97380bd Update Python 3.14 installation instructions (#10385)
Removed mention of installing pytorch nightly for Python 3.14.
2025-10-17 18:22:59 -04:00
Alexander Piskun
99ce2a1f66 convert nodes_controlnet.py to V3 schema (#10202) 2025-10-17 14:13:05 -07:00
rattus128
b1467da480 execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (#10368)
* execution: fold in dependency aware caching

This makes --cache-none compatiable with lazy and expanded
subgraphs.

Currently the --cache-none option is powered by the
DependencyAwareCache. The cache attempts to maintain a parallel
copy of the execution list data structure, however it is only
setup once at the start of execution and does not get meaninigful
updates to the execution list.

This causes multiple problems when --cache-none is used with lazy
and expanded subgraphs as the DAC does not accurately update its
copy of the execution data structure.

DAC has an attempt to handle subgraphs ensure_subcache however
this does not accurately connect to nodes outside the subgraph.
The current semantics of DAC are to free a node ASAP after the
dependent nodes are executed.

This means that if a subgraph refs such a node it will be requed
and re-executed by the execution_list but DAC wont see it in
its to-free lists anymore and leak memory.

Rather than try and cover all the cases where the execution list
changes from inside the cache, move the while problem to the
executor which maintains an always up-to-date copy of the wanted
data-structure.

The executor now has a fast-moving run-local cache of its own.
Each _to node has its own mini cache, and the cache is unconditionally
primed at the time of add_strong_link.

add_strong_link is called for all of static workflows, lazy links
and expanded subgraphs so its the singular source of truth for
output dependendencies.

In the case of a cache-hit, the executor cache will hold the non-none
value (it will respect updates if they happen somehow as well).

In the case of a cache-miss, the executor caches a None and will
wait for a notification to update the value when the node completes.

When a node completes execution, it simply releases its mini-cache
and in turn its strong refs on its direct anscestor outputs, allowing
for ASAP freeing (same as the DependencyAwareCache but a little more
automatic).

This now allows for re-implementation of --cache-none with no cache
at all. The dependency aware cache was also observing the dependency
sematics for the objects and UI cache which is not accurate (this
entire logic was always outputs specific).

This also prepares for more complex caching strategies (such as RAM
pressure based caching), where a cache can implement any freeing
strategy completely independently of the DepedancyAwareness
requirement.

* main: re-implement --cache-none as no cache at all

The execution list now tracks the dependency aware caching more
correctly that the DependancyAwareCache.

Change it to a cache that does nothing.

* test_execution: add --cache-none to the test suite

--cache-none is now expected to work universally. Run it through the
full unit test suite. Propagate the server parameterization for whether
or not the server is capabale of caching, so that the minority of tests
that specifically check for cache hits can if else. Hard assert NOT
caching in the else to give some coverage of --cache-none expected
behaviour to not acutally cache.
2025-10-17 13:55:15 -07:00
Jedrzej Kosinski
d8d60b5609 Do batch_slice in EasyCache's apply_cache_diff (#10376) 2025-10-17 00:39:37 -04:00
comfyanonymous
b1293d50ef workaround also works on cudnn 91200 (#10375) 2025-10-16 19:59:56 -04:00
comfyanonymous
19b466160c Workaround for nvidia issue where VAE uses 3x more memory on torch 2.9 (#10373) 2025-10-16 18:16:03 -04:00
Alexander Piskun
bc0ad9bb49 fix(api-nodes): remove "veo2" model from Veo3 node (#10372) 2025-10-16 10:12:50 -07:00
Rizumu Ayaka
4054b4bf38 feat: deprecated API alert (#10366) 2025-10-16 01:13:31 -07:00
Arjan Singh
55ac7d333c Bump frontend to 1.28.7 (#10364) 2025-10-15 20:30:39 -07:00
80 changed files with 7727 additions and 6955 deletions

View File

@@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -18,9 +18,9 @@ jobs:
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
cache_tag: "cu130"
python_minor: "13"
python_patch: "6"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true

View File

@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "130"
python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "9"
# push:
# branches:
# - master

View File

@@ -176,6 +176,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
@@ -197,10 +199,12 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) and install pytorch nightly but it is not recommended.
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
### Instructions:
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

112
app/subgraph_manager.py Normal file
View File

@@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -144,8 +145,9 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
PinnedMem = "pinned_memory"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@@ -15,14 +15,13 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -39,7 +38,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -65,18 +64,6 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -90,7 +77,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet: Union[ControlBase, None] = None
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -98,7 +85,6 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -125,38 +111,17 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -165,7 +130,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c: ControlBase):
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -319,14 +284,6 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -353,11 +310,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -393,12 +352,13 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
@@ -872,14 +832,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -189,15 +189,15 @@ class ChromaRadiance(Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
@@ -240,17 +240,8 @@ class ChromaRadiance(Chroma):
end = min(i + tile_size, num_patches)
# Slice the current tile from the input tensors
nerf_hidden_tile = nerf_hidden[:, i:end, :]
nerf_pixels_tile = nerf_pixels[:, i:end, :]
# Get the actual number of patches in this tile (can be smaller for the last tile)
num_patches_tile = nerf_hidden_tile.shape[1]
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

View File

@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

@@ -588,7 +588,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -601,10 +601,22 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -630,7 +642,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):

View File

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -197,8 +197,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep
@@ -327,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:

View File

@@ -6,6 +6,20 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -213,7 +227,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 32
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@@ -15,7 +15,6 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@@ -28,10 +27,6 @@ import platform
import weakref
import gc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -94,6 +89,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml
directml_enabled = True
device_index = args.directml
@@ -191,25 +187,6 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -354,14 +331,21 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try:
if is_amd():
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -395,6 +379,9 @@ try:
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
try:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
@@ -457,13 +444,9 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models: list[LoadedModel] = []
current_loaded_models = []
def module_size(module):
module_mem = 0
@@ -474,7 +457,7 @@ def module_size(module):
return module_mem
class LoadedModel:
def __init__(self, model: ModelPatcher):
def __init__(self, model):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -482,7 +465,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model: ModelPatcher):
def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -1016,12 +999,6 @@ def device_supports_non_blocking(device):
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@@ -1036,6 +1013,16 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@@ -1044,21 +1031,17 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@@ -1067,18 +1050,14 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1103,6 +1082,60 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
if PerformanceFeature.PinnedMem in args.fast:
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
else:
MAX_PINNED_MEMORY = -1
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_nvidia():
return False
if not is_device_cpu(tensor.device):
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_nvidia():
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention
@@ -1355,7 +1388,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False
@@ -1424,34 +1457,8 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
for device in get_all_torch_devices():
free_memory(1e30, device)
free_memory(1e30, get_torch_device())
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
#TODO: might be cleaner to put this somewhere else
import threading

View File

@@ -87,15 +87,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -241,6 +238,7 @@ class ModelPatcher:
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -260,9 +258,6 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -281,6 +276,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -300,6 +298,7 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -341,92 +340,18 @@ class ModelPatcher:
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
n.model = copy.deepcopy(n.model)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@@ -530,6 +455,19 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -698,6 +636,21 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -719,9 +672,11 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
offloaded = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
@@ -738,6 +693,7 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -763,6 +719,7 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
@@ -793,7 +750,9 @@ class ModelPatcher:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -801,11 +760,17 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -842,6 +807,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -937,6 +903,9 @@ class ModelPatcher:
memory_freed += module_mem
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
@@ -1063,7 +1032,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models: list[ModelPatcher] = []
all_models = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1117,13 +1086,9 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
callback(self, timestep)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1136,18 +1101,12 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1159,28 +1118,6 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
@@ -1371,5 +1308,6 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@@ -1,167 +0,0 @@
from __future__ import annotations
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
# multigpu_models = model.get_additional_models_with_key("multigpu")
# new_multigpu_models = []
# for m in multigpu_models:
# if m.load_device in limit_extra_devices:
# new_multigpu_models.append(m)
# model.set_additional_models("multigpu", new_multigpu_models)
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

91
comfy/nested_tensor.py Normal file
View File

@@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@@ -25,6 +25,9 @@ import comfy.rmsnorm
import contextlib
def run_every_op():
if torch.compiler.is_compiling():
return
comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
@@ -32,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -52,15 +55,26 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
@@ -69,32 +83,58 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if device is None:
device = input.device
offload_stream = comfy.model_management.get_offload_stream(device)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight = weight.to(dtype=dtype)
if weight_has_function:
with wf_context:
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -107,8 +147,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -122,8 +164,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -137,8 +181,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -151,9 +197,20 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -167,8 +224,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -183,11 +242,14 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -203,11 +265,15 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -226,10 +292,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -248,10 +316,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -269,8 +339,11 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -324,20 +397,18 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight
scale_input = self.scale_input
@@ -349,23 +420,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
uncast_bias_weight(self, w, bias, offload_stream)
return o
return None
@@ -385,8 +453,10 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -414,12 +484,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@@ -458,7 +530,130 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": "TensorCoreFP8Layout",
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
}
}
}
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
mixin = QUANT_FORMAT_MIXINS[quant_format]
self.layout_type = mixin["layout_type"]
scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
requires_grad=False
)
for param_name, param_value in mixin["parameters"].items():
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@@ -3,8 +3,6 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

512
comfy/quant_ops.py Normal file
View File

@@ -0,0 +1,512 @@
import torch
import logging
from typing import Tuple, Dict
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata)
qt_dest._layout_type = src._layout_type
qt_dest._layout_params = _copy_layout_params(src._layout_params)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -4,13 +4,9 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
import comfy.nested_tensor
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = comfy.nested_tensor.NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
return noises
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)

View File

@@ -1,17 +1,16 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -107,47 +106,6 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -172,8 +130,7 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -225,18 +182,3 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
from typing import TYPE_CHECKING, Callable, NamedTuple
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -22,7 +20,6 @@ import comfy.context_windows
import comfy.utils
import scipy.stats
import numpy
import threading
def add_area_dims(area, num_dims):
@@ -145,7 +142,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list, device=None):
def cond_cat(c_list):
temp = {}
for x in c_list:
for k in x:
@@ -157,8 +154,6 @@ def cond_cat(c_list, device=None):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@@ -218,9 +213,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -252,7 +245,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
model.current_patcher.prepare_state(timestep)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -353,196 +346,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
for thread in threads:
thread.join()
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -847,8 +650,6 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -981,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -991,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -1091,9 +892,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -1102,17 +901,18 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in transformer_options:
patches = transformer_options["patches"]
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -1122,8 +922,8 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -1131,6 +931,7 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -1161,11 +962,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -1179,12 +980,10 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
@@ -1195,13 +994,9 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
@@ -1212,6 +1007,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1231,7 +1032,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1239,6 +1040,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output

View File

@@ -143,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -293,6 +296,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -595,6 +599,16 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
@@ -1262,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1296,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@@ -1330,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
@@ -1346,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod

View File

@@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out
def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes
def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors

View File

@@ -1,718 +0,0 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
import os
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
import uuid
from io import BytesIO
import av
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
def get_size(path_or_object: Union[str, io.BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")

View File

@@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@@ -160,15 +122,8 @@ class BFLStatus(str, Enum):
error = "Error"
class BFLFluxProStatusResponse(BaseModel):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)

View File

@@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)

View File

@@ -1,13 +1,20 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"

View File

@@ -0,0 +1,111 @@
from typing import Optional, Union
from enum import Enum
from pydantic import BaseModel, Field
class Image2(BaseModel):
bytesBase64Encoded: str
gcsUri: Optional[str] = None
mimeType: Optional[str] = None
class Image3(BaseModel):
bytesBase64Encoded: Optional[str] = None
gcsUri: str
mimeType: Optional[str] = None
class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
prompt: str = Field(..., description='Text description of the video')
class PersonGeneration1(str, Enum):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'
class Parameters1(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
generateAudio: Optional[bool] = Field(
None,
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None
parameters: Optional[Parameters1] = None
class VeoGenVidResponse(BaseModel):
name: str = Field(
...,
description='Operation resource name',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
],
)
class VeoGenVidPollRequest(BaseModel):
operationName: str = Field(
...,
description='Full operation name (from predict response)',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
],
)
class Video(BaseModel):
bytesBase64Encoded: Optional[str] = Field(
None, description='Base64-encoded video content'
)
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
mimeType: Optional[str] = Field(None, description='Video MIME type')
class Error1(BaseModel):
code: Optional[int] = Field(None, description='Error code')
message: Optional[str] = Field(None, description='Error message')
class Response1(BaseModel):
field_type: Optional[str] = Field(
None,
alias='@type',
examples=[
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
],
)
raiMediaFilteredCount: Optional[int] = Field(
None, description='Count of media filtered by responsible AI policies'
)
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
class VeoGenVidPollResponse(BaseModel):
done: Optional[bool] = None
error: Optional[Error1] = Field(
None, description='Error details if operation failed'
)
name: Optional[str] = None
response: Optional[Response1] = Field(
None, description='The actual prediction response if done is true'
)

View File

@@ -1,146 +1,46 @@
import asyncio
import io
from inspect import cleandoc
from typing import Union, Optional
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLStatus,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_aspect_ratio,
process_image_response,
download_url_to_image_tensor,
poll_op,
resize_mask_to_image,
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_string,
)
import numpy as np
from PIL import Image
import aiohttp
import torch
import base64
import time
from server import PromptServer
def convert_mask_to_image(mask: torch.Tensor):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask]*3, dim=-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
async def handle_bfl_synchronous_operation(
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = await operation.execute()
return await _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)
async def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
retries_404 = 0
max_retries_404 = 5
retry_404_seconds = 2
retry_202_seconds = 2
retry_pending_seconds = 1
async with aiohttp.ClientSession() as session:
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
async with session.get(polling_url) as response:
if response.status == 200:
result = await response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
async with session.get(img_url) as img_resp:
return process_image_response(await img_resp.content.read())
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
await asyncio.sleep(retry_pending_seconds)
continue
elif response.status == 404:
if retries_404 < max_retries_404:
retries_404 += 1
await asyncio.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status == 202:
await asyncio.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")
def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
# remove batch dimension if present
if len(scaled_image.shape) > 3:
scaled_image = scaled_image[0]
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
return base64.b64encode(img_byte_arr.getvalue()).decode()
class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -158,7 +58,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"seed",
@@ -203,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
def validate_inputs(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
return True
@classmethod
@@ -220,49 +113,44 @@ class FluxProUltraImageNode(IO.ComfyNode):
cls,
prompt: str,
aspect_ratio: str,
prompt_upsampling=False,
raw=False,
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
prompt_upsampling: bool = False,
raw: bool = False,
seed: int = 0,
image_prompt: Optional[torch.Tensor] = None,
image_prompt_strength: float = 0.1,
) -> IO.NodeOutput:
if image_prompt is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
method=HttpMethod.POST,
request_model=BFLFluxProUltraGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxProUltraGenerateRequest(
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxProUltraGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
aspect_ratio=aspect_ratio,
raw=raw,
image_prompt=(
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
),
image_prompt_strength=(
None if image_prompt is None else round(image_prompt_strength, 2)
),
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxKontextProImageNode(IO.ComfyNode):
@@ -270,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -347,46 +230,43 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor]=None,
input_image: Optional[torch.Tensor] = None,
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
aspect_ratio = validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=cls.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
initial_response = await sync_op(
cls,
ApiEndpoint(path=cls.BFL_PATH, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=aspect_ratio,
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)),
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxKontextMaxImageNode(FluxKontextProImageNode):
@@ -422,7 +302,9 @@ class FluxProImageNode(IO.ComfyNode):
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"width",
@@ -481,20 +363,15 @@ class FluxProImageNode(IO.ComfyNode):
image_prompt=None,
# image_prompt_strength=0.1,
) -> IO.NodeOutput:
image_prompt = (
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1/generate",
method=HttpMethod.POST,
request_model=BFLFluxProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
method="POST",
),
request=BFLFluxProGenerateRequest(
response_model=BFLFluxProGenerateResponse,
data=BFLFluxProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
width=width,
@@ -502,13 +379,23 @@ class FluxProImageNode(IO.ComfyNode):
seed=seed,
image_prompt=image_prompt,
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxProExpandNode(IO.ComfyNode):
@@ -534,7 +421,9 @@ class FluxProExpandNode(IO.ComfyNode):
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"top",
@@ -610,16 +499,11 @@ class FluxProExpandNode(IO.ComfyNode):
guidance: float,
seed=0,
) -> IO.NodeOutput:
image = convert_image_to_base64(image)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-expand/generate",
method=HttpMethod.POST,
request_model=BFLFluxExpandImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxExpandImageRequest(
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxExpandImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
top=top,
@@ -629,16 +513,25 @@ class FluxProExpandNode(IO.ComfyNode):
steps=steps,
guidance=guidance,
seed=seed,
image=image,
image=tensor_to_base64_string(image),
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxProFillNode(IO.ComfyNode):
@@ -665,7 +558,9 @@ class FluxProFillNode(IO.ComfyNode):
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Float.Input(
"guidance",
@@ -712,272 +607,37 @@ class FluxProFillNode(IO.ComfyNode):
) -> IO.NodeOutput:
# prepare mask
mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed
image = convert_image_to_base64(image[:, :, :, :3])
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-fill/generate",
method=HttpMethod.POST,
request_model=BFLFluxFillImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxFillImageRequest(
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxFillImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
image=image,
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
class FluxProCannyNode(IO.ComfyNode):
"""
Generate image using a control image (canny).
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxProCannyNode",
display_name="Flux.1 Canny Control Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("control_image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Float.Input(
"canny_low_threshold",
default=0.1,
min=0.01,
max=0.99,
step=0.01,
tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True",
),
IO.Float.Input(
"canny_high_threshold",
default=0.4,
min=0.01,
max=0.99,
step=0.01,
tooltip="High threshold for Canny edge detection; ignored if skip_processing is True",
),
IO.Boolean.Input(
"skip_preprocessing",
default=False,
tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
),
IO.Float.Input(
"guidance",
default=30,
min=1,
max=100,
tooltip="Guidance strength for the image generation process",
),
IO.Int.Input(
"steps",
default=50,
min=15,
max=50,
tooltip="Number of steps for the image generation process",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
queued_statuses=[],
)
@classmethod
async def execute(
cls,
control_image: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
canny_low_threshold: float,
canny_high_threshold: float,
skip_preprocessing: bool,
steps: int,
guidance: float,
seed=0,
) -> IO.NodeOutput:
control_image = convert_image_to_base64(control_image[:, :, :, :3])
preprocessed_image = None
# scale canny threshold between 0-500, to match BFL's API
def scale_value(value: float, min_val=0, max_val=500):
return min_val + value * (max_val - min_val)
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
if skip_preprocessing:
preprocessed_image = control_image
control_image = None
canny_low_threshold = None
canny_high_threshold = None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-canny/generate",
method=HttpMethod.POST,
request_model=BFLFluxCannyImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxCannyImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
control_image=control_image,
canny_low_threshold=canny_low_threshold,
canny_high_threshold=canny_high_threshold,
preprocessed_image=preprocessed_image,
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
class FluxProDepthNode(IO.ComfyNode):
"""
Generate image using a control image (depth).
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxProDepthNode",
display_name="Flux.1 Depth Control Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("control_image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Boolean.Input(
"skip_preprocessing",
default=False,
tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
),
IO.Float.Input(
"guidance",
default=15,
min=1,
max=100,
tooltip="Guidance strength for the image generation process",
),
IO.Int.Input(
"steps",
default=50,
min=15,
max=50,
tooltip="Number of steps for the image generation process",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
control_image: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
skip_preprocessing: bool,
steps: int,
guidance: float,
seed=0,
) -> IO.NodeOutput:
control_image = convert_image_to_base64(control_image[:,:,:,:3])
preprocessed_image = None
if skip_preprocessing:
preprocessed_image = control_image
control_image = None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-depth/generate",
method=HttpMethod.POST,
request_model=BFLFluxDepthImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxDepthImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
control_image=control_image,
preprocessed_image=preprocessed_image,
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class BFLExtension(ComfyExtension):
@@ -990,8 +650,6 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode,
FluxProExpandNode,
FluxProFillNode,
FluxProCannyNode,
FluxProDepthNode,
]

View File

@@ -1,35 +1,27 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing_extensions import override
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.util.validation_utils import (
validate_image_aspect_ratio_range,
get_number_of_images,
validate_image_dimensions,
)
from comfy_api_nodes.apis.client import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
EmptyRequest,
HttpMethod,
SynchronousOperation,
PollingOperation,
T,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
upload_images_to_comfyapi,
validate_string,
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
)
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
# Long-running tasks endpoints(e.g., video)
@@ -46,13 +38,14 @@ class Image2ImageModelName(str, Enum):
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
@@ -208,35 +201,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
task_id: str,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
completed_statuses=[
"succeeded",
],
failed_statuses=[
"cancelled",
"failed",
],
status_extractor=lambda response: response.status,
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
).execute()
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@@ -303,7 +267,7 @@ class ByteDanceImageNode(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image",
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
],
@@ -341,8 +305,7 @@ class ByteDanceImageNode(IO.ComfyNode):
w, h = width, height
if not (512 <= w <= 2048) or not (512 <= h <= 2048):
raise ValueError(
f"Custom size out of range: {w}x{h}. "
"Both width and height must be between 512 and 2048 pixels."
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels."
)
payload = Text2ImageTaskCreationRequest(
@@ -353,20 +316,12 @@ class ByteDanceImageNode(IO.ComfyNode):
guidance_scale=guidance_scale,
watermark=watermark,
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_IMAGE_ENDPOINT,
method=HttpMethod.POST,
request_model=Text2ImageTaskCreationRequest,
response_model=ImageTaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
data=payload,
response_model=ImageTaskCreationResponse,
)
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@@ -420,7 +375,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image",
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
],
@@ -448,17 +403,8 @@ class ByteDanceImageEditNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
source_url = (await upload_images_to_comfyapi(
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
))[0]
validate_image_aspect_ratio(image, (1, 3), (3, 1))
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
payload = Image2ImageTaskCreationRequest(
model=model,
prompt=prompt,
@@ -467,16 +413,12 @@ class ByteDanceImageEditNode(IO.ComfyNode):
guidance_scale=guidance_scale,
watermark=watermark,
)
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_IMAGE_ENDPOINT,
method=HttpMethod.POST,
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
data=payload,
response_model=ImageTaskCreationResponse,
)
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@@ -504,7 +446,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Image.Input(
"image",
tooltip="Input image(s) for image-to-image generation. "
"List of 1-10 images for single or multi-reference generation.",
"List of 1-10 images for single or multi-reference generation.",
optional=True,
),
IO.Combo.Input(
@@ -534,9 +476,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
"sequential_image_generation",
options=["disabled", "auto"],
tooltip="Group image generation mode. "
"'disabled' generates a single image. "
"'auto' lets the model decide whether to generate multiple related images "
"(e.g., story scenes, character variations).",
"'disabled' generates a single image. "
"'auto' lets the model decide whether to generate multiple related images "
"(e.g., story scenes, character variations).",
optional=True,
),
IO.Int.Input(
@@ -547,7 +489,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Maximum number of images to generate when sequential_image_generation='auto'. "
"Total images (input + generated) cannot exceed 15.",
"Total images (input + generated) cannot exceed 15.",
optional=True,
),
IO.Int.Input(
@@ -564,7 +506,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image.",
tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True,
),
IO.Boolean.Input(
@@ -611,8 +553,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
w, h = width, height
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
raise ValueError(
f"Custom size out of range: {w}x{h}. "
"Both width and height must be between 1024 and 4096 pixels."
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
@@ -621,41 +562,31 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
"The maximum number of generated images plus the number of reference images cannot exceed 15."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
reference_images_urls = []
if n_input_images:
for i in image:
validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
reference_images_urls = (await upload_images_to_comfyapi(
validate_image_aspect_ratio(i, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi(
cls,
image,
max_images=n_input_images,
mime_type="image/png",
auth_kwargs=auth_kwargs,
))
payload = Seedream4TaskCreationRequest(
model=model,
prompt=prompt,
image=reference_images_urls,
size=f"{w}x{h}",
seed=seed,
sequential_image_generation=sequential_image_generation,
sequential_image_generation_options=Seedream4Options(max_images=max_images),
watermark=watermark,
)
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_IMAGE_ENDPOINT,
method=HttpMethod.POST,
request_model=Seedream4TaskCreationRequest,
response_model=ImageTaskCreationResponse,
)
response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
response_model=ImageTaskCreationResponse,
data=Seedream4TaskCreationRequest(
model=model,
prompt=prompt,
image=reference_images_urls,
size=f"{w}x{h}",
seed=seed,
sequential_image_generation=sequential_image_generation,
sequential_image_generation_options=Seedream4Options(max_images=max_images),
watermark=watermark,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
@@ -719,13 +650,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
"camera_fixed",
default=False,
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
),
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the video.",
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
],
@@ -764,19 +695,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
f"--camerafixed {str(camera_fixed).lower()} "
f"--watermark {str(watermark).lower()}"
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
return await process_video_task(
request_model=Text2VideoTaskCreationRequest,
payload=Text2VideoTaskCreationRequest(
model=model,
content=[TaskTextContent(text=prompt)],
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
cls,
payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -840,13 +761,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
"camera_fixed",
default=False,
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
),
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the video.",
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
],
@@ -877,15 +798,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0]
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
prompt = (
f"{prompt} "
f"--resolution {resolution} "
@@ -897,13 +812,11 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
)
return await process_video_task(
request_model=Image2VideoTaskCreationRequest,
cls,
payload=Image2VideoTaskCreationRequest(
model=model,
content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))],
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -971,13 +884,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
"camera_fixed",
default=False,
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
"to fix the camera to your prompt, but does not guarantee the actual effect.",
"to fix the camera to your prompt, but does not guarantee the actual effect.",
optional=True,
),
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the video.",
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
],
@@ -1010,18 +923,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
for i in (first_frame, last_frame):
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
download_urls = await upload_images_to_comfyapi(
cls,
image_tensor_pair_to_batch(first_frame, last_frame),
max_images=2,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
prompt = (
@@ -1035,7 +943,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
)
return await process_video_task(
request_model=Image2VideoTaskCreationRequest,
cls,
payload=Image2VideoTaskCreationRequest(
model=model,
content=[
@@ -1044,8 +952,6 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"),
],
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -1108,7 +1014,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the video.",
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
],
@@ -1139,17 +1045,9 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
for image in images:
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
image_urls = await upload_images_to_comfyapi(
images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs
)
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
prompt = (
f"{prompt} "
f"--resolution {resolution} "
@@ -1160,42 +1058,32 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
)
x = [
TaskTextContent(text=prompt),
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls]
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls],
]
return await process_video_task(
request_model=Image2VideoTaskCreationRequest,
payload=Image2VideoTaskCreationRequest(
model=model,
content=x,
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
async def process_video_task(
request_model: Type[T],
cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
auth_kwargs: dict,
node_id: str,
estimated_duration: Optional[int],
) -> IO.NodeOutput:
initial_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_TASK_ENDPOINT,
method=HttpMethod.POST,
request_model=request_model,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
response = await poll_until_finished(
auth_kwargs,
initial_response.id,
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
data=payload,
response_model=TaskCreationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
status_extractor=lambda r: r.status,
estimated_duration=estimated_duration,
node_id=node_id,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
@@ -1221,5 +1109,6 @@ class ByteDanceExtension(ComfyExtension):
ByteDanceImageReferenceNode,
]
async def comfy_entrypoint() -> ByteDanceExtension:
return ByteDanceExtension()

View File

@@ -2,45 +2,47 @@
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
from __future__ import annotations
import json
import time
import os
import uuid
import base64
from io import BytesIO
import json
import os
import time
import uuid
from enum import Enum
from typing import Optional, Literal
from io import BytesIO
from typing import Literal, Optional
import torch
from typing_extensions import override
import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiInlineData,
GeminiPart,
GeminiMimeType,
GeminiPart,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.apis.gemini_api import (
GeminiImageConfig,
GeminiImageGenerateContentRequest,
GeminiImageGenerationConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
validate_string,
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
bytesio_to_image_tensor,
sync_op,
tensor_to_base64_string,
validate_string,
video_to_base64_string,
)
from comfy_api.util import VideoContainer, VideoCodec
from server import PromptServer
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
@@ -66,50 +68,6 @@ class GeminiImageModel(str, Enum):
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
def get_gemini_image_endpoint(
model: GeminiImageModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiImageModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiImageGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
@@ -122,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
@@ -136,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
return image_parts
def create_text_part(text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def get_parts_from_response(
response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
"""
Filter response parts by their type.
@@ -178,14 +104,10 @@ def get_parts_by_type(
List of response parts matching the requested type.
"""
parts = []
for part in get_parts_from_response(response):
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@@ -213,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1,1024,1024,4))
return torch.zeros((1, 1024, 1024, 4))
return torch.cat(image_tensors, dim=0)
class GeminiNode(ComfyNodeABC):
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@@ -228,96 +150,79 @@ class GeminiNode(ComfyNodeABC):
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
},
def define_schema(cls):
return IO.Schema(
node_id="GeminiNode",
display_name="Google Gemini",
category="api node/text/Gemini",
description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text inputs to the model, used to generate a response. "
"You can include detailed instructions, questions, or context for the model.",
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro.value,
},
IO.Combo.Input(
"model",
options=GeminiModel,
default=GeminiModel.gemini_2_5_pro,
tooltip="The Gemini model to use for generating responses.",
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
IO.Image.Input(
"images",
optional=True,
tooltip="Optional image(s) to use as context for the model. "
"To include multiple images, you can use the Batch Images node.",
),
"audio": (
IO.AUDIO,
{
"tooltip": "Optional audio to use as context for the model.",
"default": None,
},
IO.Audio.Input(
"audio",
optional=True,
tooltip="Optional audio to use as context for the model.",
),
"video": (
IO.VIDEO,
{
"tooltip": "Optional video to use as context for the model.",
"default": None,
},
IO.Video.Input(
"video",
optional=True,
tooltip="Optional video to use as context for the model.",
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
RETURN_TYPES = ("STRING",)
FUNCTION = "api_call"
CATEGORY = "api node/text/Gemini"
API_NODE = True
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
Args:
video_input: Video tensor from ComfyUI.
**kwargs: Additional arguments to pass to the conversion function.
Returns:
List of GeminiPart objects containing the encoded video.
"""
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
codec=VideoCodec.H264
],
outputs=[
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
return [
GeminiPart(
inlineData=GeminiInlineData(
@@ -327,7 +232,8 @@ class GeminiNode(ComfyNodeABC):
)
]
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
@classmethod
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
@@ -340,10 +246,10 @@ class GeminiNode(ComfyNodeABC):
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = {
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
"sample_rate": audio_input["sample_rate"],
}
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
@@ -360,38 +266,38 @@ class GeminiNode(ComfyNodeABC):
)
return audio_parts
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
model: GeminiModel,
images: Optional[IO.IMAGE] = None,
audio: Optional[IO.AUDIO] = None,
video: Optional[IO.VIDEO] = None,
model: str,
seed: int,
images: Optional[torch.Tensor] = None,
audio: Optional[Input.Audio] = None,
video: Optional[Input.Video] = None,
files: Optional[list[GeminiPart]] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
parts.extend(cls.create_audio_parts(audio))
if video is not None:
parts.extend(self.create_video_parts(video))
parts.extend(cls.create_video_parts(video))
if files is not None:
parts.extend(files)
# Create response
response = await SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
@@ -399,15 +305,15 @@ class GeminiNode(ComfyNodeABC):
)
]
),
auth_kwargs=kwargs,
).execute()
response_model=GeminiGenerateContentResponse,
)
# Get result output
output_text = get_text_from_response(response)
if unique_id and output_text:
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
@@ -427,10 +333,10 @@ class GeminiNode(ComfyNodeABC):
render_spec,
)
return (output_text or "Empty response from Gemini model...",)
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
class GeminiInputFiles(ComfyNodeABC):
class GeminiInputFiles(IO.ComfyNode):
"""
Loads and formats input files for use with the Gemini API.
@@ -441,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC):
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
def define_schema(cls):
"""
For details about the supported file input types, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
@@ -456,39 +362,37 @@ class GeminiInputFiles(ComfyNodeABC):
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
return IO.Schema(
node_id="GeminiInputFiles",
display_name="Gemini Input Files",
category="api node/text/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. "
"🛈 TIP: Can be chained together with other Gemini Input File nodes.",
inputs=[
IO.Combo.Input(
"file",
options=input_files,
default=input_files[0] if input_files else None,
tooltip="Input files to include as context for the model. "
"Only accepts text (.txt) and PDF (.pdf) files for now.",
),
},
"optional": {
"GEMINI_INPUT_FILES": (
IO.Custom("GEMINI_INPUT_FILES").Input(
"GEMINI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
optional=True,
tooltip="An optional additional file(s) to batch together with the file loaded from this node. "
"Allows chaining of input files so that a single message can include multiple input files.",
),
},
}
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/Gemini"
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.application_pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
],
outputs=[
IO.Custom("GEMINI_INPUT_FILES").Output(),
],
)
@classmethod
def create_file_part(cls, file_path: str) -> GeminiPart:
mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
@@ -501,120 +405,95 @@ class GeminiInputFiles(ComfyNodeABC):
)
)
def prepare_files(
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
) -> tuple[list[GeminiPart]]:
"""
Loads and formats input files for Gemini API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_file_part(file_path)
files = [input_file_content] + GEMINI_INPUT_FILES
return (files,)
class GeminiImage(ComfyNodeABC):
"""
Node to generate text and image responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, files) to generate coherent
text and image responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for generation",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
# TODO: later we can add this parameter later
# "n": (
# IO.INT,
# {
# "default": 1,
# "min": 1,
# "max": 8,
# "step": 1,
# "display": "number",
# "tooltip": "How many images to generate",
# },
# ),
"aspect_ratio": (
IO.COMBO,
{
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
"default": "auto",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
"""Loads and formats input files for Gemini API."""
if GEMINI_INPUT_FILES is None:
GEMINI_INPUT_FILES = []
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = cls.create_file_part(file_path)
return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES)
RETURN_TYPES = (IO.IMAGE, IO.STRING)
FUNCTION = "api_call"
CATEGORY = "api node/image/Gemini"
DESCRIPTION = "Edit images synchronously via Google API."
API_NODE = True
async def api_call(
self,
class GeminiImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Google Gemini Image",
category="api node/image/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt for generation",
default="",
),
IO.Combo.Input(
"model",
options=GeminiImageModel,
default=GeminiImageModel.gemini_2_5_flash_image,
tooltip="The Gemini model to use for generating responses.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Image.Input(
"images",
optional=True,
tooltip="Optional image(s) to use as context for the model. "
"To include multiple images, you can use the Batch Images node.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.Combo.Input(
"aspect_ratio",
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
default="auto",
tooltip="Defaults to matching the output image size to that of your input image, "
"or otherwise generates 1:1 squares.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
model: GeminiImageModel,
images: Optional[IO.IMAGE] = None,
model: str,
seed: int,
images: Optional[torch.Tensor] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
aspect_ratio: str = "auto",
unique_id: Optional[str] = None,
**kwargs,
):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [create_text_part(prompt)]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
@@ -626,29 +505,27 @@ class GeminiImage(ComfyNodeABC):
if files is not None:
parts.extend(files)
response = await SynchronousOperation(
endpoint=get_gemini_image_endpoint(model),
request=GeminiImageGenerateContentRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
),
GeminiContent(role="user", parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"],
responseModalities=["TEXT", "IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
)
),
),
auth_kwargs=kwargs,
).execute()
response_model=GeminiGenerateContentResponse,
)
output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
@@ -669,17 +546,18 @@ class GeminiImage(ComfyNodeABC):
)
output_text = output_text or "Empty response from Gemini model..."
return (output_image, output_text,)
return IO.NodeOutput(output_image, output_text)
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiImageNode": GeminiImage,
"GeminiInputFiles": GeminiInputFiles,
}
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GeminiNode,
GeminiImage,
GeminiInputFiles,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiImageNode": "Google Gemini Image",
"GeminiInputFiles": "Gemini Input Files",
}
async def comfy_entrypoint() -> GeminiExtension:
return GeminiExtension()

View File

@@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(IO.ComfyNode):
@classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
response_model=IdeogramGenerateResponse,
data=edit_request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
@@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
IdeogramV3,
]
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

View File

@@ -5,8 +5,7 @@ For source of truth on the allowed permutations of request fields, please refere
"""
from __future__ import annotations
from typing import Optional, TypeVar, Any
from collections.abc import Callable
from typing import Optional, TypeVar
import math
import logging
@@ -15,7 +14,6 @@ from typing_extensions import override
import torch
from comfy_api_nodes.apis import (
KlingTaskStatus,
KlingCameraControl,
KlingCameraConfig,
KlingCameraControlType,
@@ -52,26 +50,20 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName,
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string,
download_url_to_video_output,
upload_video_to_comfyapi,
upload_audio_to_comfyapi,
download_url_to_image_tensor,
validate_string,
)
from comfy_api_nodes.util.validation_utils import (
from comfy_api_nodes.util import (
validate_image_dimensions,
validate_image_aspect_ratio,
validate_video_dimensions,
validate_video_duration,
tensor_to_base64_string,
validate_string,
upload_audio_to_comfyapi,
download_url_to_image_tensor,
upload_video_to_comfyapi,
download_url_to_video_output,
sync_op,
ApiEndpoint,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.input.basic_types import AudioInput
@@ -214,34 +206,6 @@ VOICES_CONFIG = {
}
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
KlingTaskStatus.succeed.value,
],
failed_statuses=[KlingTaskStatus.failed.value],
status_extractor=lambda response: (
response.data.task_status.value
if response.data and response.data.task_status
else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()
def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs)
@@ -318,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
"""
validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
def get_video_from_response(response) -> KlingVideoResult:
@@ -377,8 +341,7 @@ async def image_result_to_node_output(
async def execute_text2video(
auth_kwargs: dict[str, str],
node_id: str,
cls: type[IO.ComfyNode],
prompt: str,
negative_prompt: str,
cfg_scale: float,
@@ -389,14 +352,11 @@ async def execute_text2video(
camera_control: Optional[KlingCameraControl] = None,
) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=KlingText2VideoRequest,
response_model=KlingText2VideoResponse,
),
request=KlingText2VideoRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=KlingText2VideoResponse,
data=KlingText2VideoRequest(
prompt=prompt if prompt else None,
negative_prompt=negative_prompt if negative_prompt else None,
duration=KlingVideoGenDuration(duration),
@@ -406,24 +366,17 @@ async def execute_text2video(
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
camera_control=camera_control,
),
auth_kwargs=auth_kwargs,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingText2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"),
response_model=KlingText2VideoResponse,
estimated_duration=AVERAGE_DURATION_T2V,
node_id=node_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
@@ -432,8 +385,7 @@ async def execute_text2video(
async def execute_image2video(
auth_kwargs: dict[str, str],
node_id: str,
cls: type[IO.ComfyNode],
start_frame: torch.Tensor,
prompt: str,
negative_prompt: str,
@@ -455,14 +407,11 @@ async def execute_image2video(
if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
model_mode = "pro" # October 5: currently "std" mode is not supported for this model
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse,
),
request=KlingImage2VideoRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=KlingImage2VideoResponse,
data=KlingImage2VideoRequest(
model_name=KlingVideoGenModelName(model_name),
image=tensor_to_base64_string(start_frame),
image_tail=(
@@ -477,24 +426,17 @@ async def execute_image2video(
duration=KlingVideoGenDuration(duration),
camera_control=camera_control,
),
auth_kwargs=auth_kwargs,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
method=HttpMethod.GET,
request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
node_id=node_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
@@ -503,8 +445,7 @@ async def execute_image2video(
async def execute_video_effect(
auth_kwargs: dict[str, str],
node_id: str,
cls: type[IO.ComfyNode],
dual_character: bool,
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
model_name: str,
@@ -530,35 +471,25 @@ async def execute_video_effect(
duration=duration,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_VIDEO_EFFECTS,
method=HttpMethod.POST,
request_model=KlingVideoEffectsRequest,
response_model=KlingVideoEffectsResponse,
),
request=KlingVideoEffectsRequest(
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"),
response_model=KlingVideoEffectsResponse,
data=KlingVideoEffectsRequest(
effect_scene=effect_scene,
input=request_input_field,
),
auth_kwargs=auth_kwargs,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVideoEffectsResponse,
),
result_url_extractor=get_video_url_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"),
response_model=KlingVideoEffectsResponse,
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
node_id=node_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
@@ -567,8 +498,7 @@ async def execute_video_effect(
async def execute_lipsync(
auth_kwargs: dict[str, str],
node_id: str,
cls: type[IO.ComfyNode],
video: VideoInput,
audio: Optional[AudioInput] = None,
voice_language: Optional[str] = None,
@@ -583,24 +513,21 @@ async def execute_lipsync(
validate_video_duration(video, 2, 10)
# Upload video to Comfy API and get download URL
video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs)
video_url = await upload_video_to_comfyapi(cls, video)
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs)
audio_url = await upload_audio_to_comfyapi(cls, audio)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
audio_url = None
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_LIP_SYNC,
method=HttpMethod.POST,
request_model=KlingLipSyncRequest,
response_model=KlingLipSyncResponse,
),
request=KlingLipSyncRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(PATH_LIP_SYNC, "POST"),
response_model=KlingLipSyncResponse,
data=KlingLipSyncRequest(
input=KlingLipSyncInputObject(
video_url=video_url,
mode=model_mode,
@@ -612,24 +539,17 @@ async def execute_lipsync(
voice_id=voice_id,
),
),
auth_kwargs=auth_kwargs,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_LIP_SYNC}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingLipSyncResponse,
),
result_url_extractor=get_video_url_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"),
response_model=KlingLipSyncResponse,
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
@@ -807,11 +727,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
model_mode, duration, model_name = MODE_TEXT2VIDEO[mode]
return await execute_text2video(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
@@ -872,11 +788,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
camera_control: Optional[KlingCameraControl] = None,
) -> IO.NodeOutput:
return await execute_text2video(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
model_name=KlingVideoGenModelName.kling_v1,
cfg_scale=cfg_scale,
model_mode=KlingVideoGenMode.std,
@@ -944,11 +856,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
end_frame: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
return await execute_image2video(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
start_frame=start_frame,
prompt=prompt,
negative_prompt=negative_prompt,
@@ -1017,11 +925,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
camera_control: KlingCameraControl,
) -> IO.NodeOutput:
return await execute_image2video(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
model_name=KlingVideoGenModelName.kling_v1_5,
start_frame=start_frame,
cfg_scale=cfg_scale,
@@ -1097,11 +1001,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
) -> IO.NodeOutput:
mode, duration, model_name = MODE_START_END_FRAME[mode]
return await execute_image2video(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt=prompt,
negative_prompt=negative_prompt,
model_name=model_name,
@@ -1162,41 +1062,27 @@ class KlingVideoExtendNode(IO.ComfyNode):
video_id: str,
) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_VIDEO_EXTEND,
method=HttpMethod.POST,
request_model=KlingVideoExtendRequest,
response_model=KlingVideoExtendResponse,
),
request=KlingVideoExtendRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"),
response_model=KlingVideoExtendResponse,
data=KlingVideoExtendRequest(
prompt=prompt if prompt else None,
negative_prompt=negative_prompt if negative_prompt else None,
cfg_scale=cfg_scale,
video_id=video_id,
),
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth,
ApiEndpoint(
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVideoExtendResponse,
),
result_url_extractor=get_video_url_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"),
response_model=KlingVideoExtendResponse,
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
node_id=cls.hidden.unique_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
@@ -1259,11 +1145,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
duration: KlingVideoGenDuration,
) -> IO.NodeOutput:
video, _, duration = await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
dual_character=True,
effect_scene=effect_scene,
model_name=model_name,
@@ -1324,11 +1206,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
return IO.NodeOutput(
*(
await execute_video_effect(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
@@ -1379,11 +1257,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
voice_language: str,
) -> IO.NodeOutput:
return await execute_lipsync(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
video=video,
audio=audio,
voice_language=voice_language,
@@ -1445,11 +1319,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
voice_id, voice_language = VOICES_CONFIG[voice]
return await execute_lipsync(
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
video=video,
text=text,
voice_language=voice_language,
@@ -1496,40 +1366,26 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_VIRTUAL_TRY_ON,
method=HttpMethod.POST,
request_model=KlingVirtualTryOnRequest,
response_model=KlingVirtualTryOnResponse,
),
request=KlingVirtualTryOnRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"),
response_model=KlingVirtualTryOnResponse,
data=KlingVirtualTryOnRequest(
human_image=tensor_to_base64_string(human_image),
cloth_image=tensor_to_base64_string(cloth_image),
model_name=model_name,
),
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth,
ApiEndpoint(
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse,
),
result_url_extractor=get_images_urls_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"),
response_model=KlingVirtualTryOnResponse,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=cls.hidden.unique_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_image_result_response(final_response)
@@ -1625,18 +1481,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
else:
image = tensor_to_base64_string(image)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_GENERATIONS,
method=HttpMethod.POST,
request_model=KlingImageGenerationsRequest,
response_model=KlingImageGenerationsResponse,
),
request=KlingImageGenerationsRequest(
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
response_model=KlingImageGenerationsResponse,
data=KlingImageGenerationsRequest(
model_name=model_name,
prompt=prompt,
negative_prompt=negative_prompt,
@@ -1647,24 +1496,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
n=n,
aspect_ratio=aspect_ratio,
),
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = await poll_until_finished(
auth,
ApiEndpoint(
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse,
),
result_url_extractor=get_images_urls_from_response,
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"),
response_model=KlingImageGenerationsResponse,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=cls.hidden.unique_id,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_image_result_response(final_response)

View File

@@ -0,0 +1,199 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op_raw,
upload_images_to_comfyapi,
validate_string,
)
MODELS_MAP = {
"LTX-2 (Pro)": "ltx-2-pro",
"LTX-2 (Fast)": "ltx-2-fast",
}
class ExecuteTaskRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
data=ExecuteTaskRequest(
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
data=ExecuteTaskRequest(
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TextToVideoNode,
ImageToVideoNode,
]
async def comfy_entrypoint() -> LtxvApiExtension:
return LtxvApiExtension()

View File

@@ -1,69 +1,51 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef,
LumaModifyImageRef,
LumaConceptChain,
LumaGeneration,
LumaGenerationRequest,
LumaImageGenerationRequest,
LumaImageIdentity,
LumaImageModel,
LumaImageReference,
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaReference,
LumaReferenceChain,
LumaImageReference,
LumaKeyframes,
LumaConceptChain,
LumaIO,
LumaVideoModel,
LumaVideoModelOutputDuration,
LumaVideoOutputResolution,
get_luma_concepts,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
process_image_response,
validate_string,
)
from server import PromptServer
import aiohttp
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(IO.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
IO.Image.Input(
"image",
@@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> IO.NodeOutput:
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
@@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):
class LumaConceptsNode(IO.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
IO.Combo.Input(
"concept1",
@@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
@@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):
class LumaImageGenerationNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
"prompt",
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str,
seed,
style_image_weight: float,
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
image_luma_ref: Optional[LumaReferenceChain] = None,
style_image: Optional[torch.Tensor] = None,
character_image: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=aspect_ratio,
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
@classmethod
async def _convert_luma_refs(
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
luma_urls.append(download_urls[0])
ref_count += 1
if ref_count >= max_refs:
@@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
return await cls._convert_luma_refs(chain, max_refs=1)
class LumaImageModifyNode(IO.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
IO.Image.Input(
"image",
@@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
image_weight: float,
seed,
) -> IO.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
image_url = download_urls[0]
# next, make Luma call with download url provided
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
class LumaTextToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[IO.Video.Output()],
hidden=[
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str,
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
luma_concepts: Optional[LumaConceptChain] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
resolution=resolution,
@@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
class LumaImageToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[IO.Video.Output()],
hidden=[
@@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
luma_concepts: LumaConceptChain = None,
) -> IO.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
raise Exception("At least one of first_image and last_image requires an input.")
keyframes = await cls._convert_to_keyframes(first_image, last_image)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
@@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
@classmethod
async def _convert_to_keyframes(
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
frame0 = None
frame1 = None
if first_image is not None:
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)

View File

@@ -1,71 +1,57 @@
from inspect import cleandoc
from typing import Optional
import logging
import torch
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
async def _generate_mm_video(
cls: type[IO.ComfyNode],
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> IO.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -73,81 +59,50 @@ async def _generate_mm_video(
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxTextToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
IO.String.Input(
"prompt_text",
@@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
class MinimaxImageToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
"image",
@@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
class MinimaxSubjectToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
"subject",
@@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
class MinimaxHailuoVideoNode(IO.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
IO.String.Input(
"prompt_text",
@@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
duration=duration,
resolution=resolution,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxExtension(ComfyExtension):

View File

@@ -1,35 +1,31 @@
import logging
from typing import Any, Callable, Optional, TypeVar
from typing import Optional
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
)
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, IO
import av
import io
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
@@ -51,13 +47,6 @@ MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
@@ -69,64 +58,7 @@ def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
@@ -170,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
}
if (width, height) not in supported_resolutions:
supported_list = ", ".join(
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
)
raise ValueError(
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
)
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -188,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
@@ -198,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
"h264", rate=stream.average_rate
)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream(
"aac", rate=stream.sample_rate
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (
estimated_frames // 16
) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
@@ -338,19 +149,14 @@ def parse_control_parameter(value):
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
return await poll_op(
cls,
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
response_model=MoonvalleyPromptResponse,
status_extractor=lambda r: (r.status if r and r.status else None),
poll_interval=16.0,
max_poll_attempts=240,
)
@@ -444,14 +250,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
steps: int,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
@@ -464,33 +266,17 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
)
)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
final_response = await get_response(cls, task_creation_response.id)
video = await download_url_to_video_output(final_response.output_url)
return IO.NodeOutput(video)
@@ -582,15 +368,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
validate_prompts(prompt, negative_prompt)
video_url = await upload_video_to_comfyapi(cls, validated_video)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
# Only include motion_intensity for Motion Transfer
control_params = {}
@@ -605,35 +386,20 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyVideoToVideoRequest(
control_type=parse_control_parameter(control_type),
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return IO.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@@ -720,14 +486,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
seed: int,
steps: int,
) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
@@ -737,30 +499,16 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
width=width_height["width"],
height=width_height["height"],
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
init_op = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=auth,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
)
task_creation_response = await init_op.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return IO.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyExtension(ComfyExtension):

File diff suppressed because it is too large Load Diff

View File

@@ -7,28 +7,23 @@ from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional, TypeVar
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
validate_string,
)
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
ApiEndpoint,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
sync_op,
poll_op,
)
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -44,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60,
max_poll_attempts=240,
).execute()
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
@@ -128,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
resolution=resolution,
duration=duration,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
@@ -187,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -206,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
@@ -313,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
@@ -387,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
negativePrompt=negative_prompt,
seed=seed,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
@@ -476,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
@@ -532,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -551,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
@@ -596,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -616,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):

View File

@@ -1,7 +1,6 @@
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from io import BytesIO
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -17,59 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
import torch
import aiohttp
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files=files,
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response_upload = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
response_model=PixverseImageUploadResponse,
files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None:
raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
@@ -95,22 +65,17 @@ class PixverseTemplateNode(IO.ComfyNode):
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return IO.NodeOutput(template_id)
class PixverseTextToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -177,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
negative_prompt: str = None,
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
@@ -186,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
@@ -207,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -228,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseImageToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@@ -316,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
img_id = await upload_image_to_pixverse(cls, image)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -330,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/img/generate",
method=HttpMethod.POST,
request_model=PixverseImageVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
@@ -347,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -368,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseTransitionVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"),
@@ -452,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt: str = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/transition/generate",
method=HttpMethod.POST,
request_model=PixverseTransitionVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
@@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -505,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixVerseExtension(ComfyExtension):

File diff suppressed because it is too large Load Diff

View File

@@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
file_path = os.path.join(save_path, i.name)
if file_path.endswith(".glb"):
model_file_path = file_path
model_file_path = os.path.join(result_folder_name, i.name)
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
async with session.get(i.url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):

View File

@@ -11,7 +11,7 @@ User Guides:
"""
from typing import Union, Optional, Any
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
@@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,
RunwayTaskStatusEnum as TaskStatus,
RunwayModelEnum as Model,
RunwayDurationEnum as Duration,
RunwayAspectRatioEnum as AspectRatio,
@@ -33,23 +32,20 @@ from comfy_api_nodes.apis import (
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_video_output,
from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
validate_string,
validate_image_dimensions,
validate_image_aspect_ratio,
upload_images_to_comfyapi,
download_url_to_video_output,
download_url_to_image_tensor,
ApiEndpoint,
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
],
failed_statuses=[
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: response.status.value,
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
).execute()
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
@@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
return await poll_op(
cls,
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status.value,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
)
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=RunwayImageToVideoResponse,
data=request,
)
initial_response = await initial_operation.execute()
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
final_response = await get_response(cls, initial_response.id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
@@ -184,9 +145,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
inputs=[
IO.String.Input(
"prompt",
@@ -239,22 +200,18 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@@ -262,15 +219,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
)
)
@@ -284,9 +235,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
inputs=[
IO.String.Input(
"prompt",
@@ -339,22 +290,18 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@@ -362,15 +309,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
@@ -385,12 +326,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
inputs=[
IO.String.Input(
"prompt",
@@ -449,26 +390,22 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
cls,
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@@ -477,17 +414,11 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
@@ -502,7 +433,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
display_name="Runway Text to Image",
category="api node/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
"You can also include reference image to guide the generation.",
inputs=[
IO.String.Input(
"prompt",
@@ -540,49 +471,34 @@ class RunwayTextToImageNode(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
method=HttpMethod.POST,
request_model=RunwayTextToImageRequest,
response_model=RunwayTextToImageResponse,
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
response_model=RunwayTextToImageResponse,
data=RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
# Poll for completion
final_response = await get_response(
cls,
initial_response.id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
)
if not final_response.output:
@@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
RunwayTextToImageNode,
]
async def comfy_entrypoint() -> RunwayExtension:
return RunwayExtension()

View File

@@ -1,23 +1,20 @@
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images
from typing_extensions import override
from comfy_api_nodes.apinode_utils import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
@@ -80,7 +77,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
@@ -111,55 +108,34 @@ class OpenAIVideoSora2(IO.ComfyNode):
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
data=Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
),
request=payload,
files=files_input,
auth_kwargs=auth,
response_model=Sora2GenerationResponse,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error.message)
raise Exception(initial_response.error["message"])
model_time_multiplier = 1 if model == "sora-2" else 2
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
response_model=Sora2GenerationResponse,
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
)
await poll_operation.execute()
return IO.NodeOutput(
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
)

View File

@@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
audio_bytes_to_audio_input,
audio_input_to_mp3,
sync_op,
poll_op,
ApiEndpoint,
)
from comfy_api_nodes.util.validation_utils import validate_audio_duration
import torch
import base64
@@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/stability/v2beta/results/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
mask_start=mask_start,
mask_end=mask_end,
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

File diff suppressed because it is too large Load Diff

View File

@@ -1,28 +1,21 @@
import logging
import base64
import aiohttp
import torch
from io import BytesIO
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
VeoGenVidRequest,
VeoGenVidResponse,
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
VeoGenVidRequest,
VeoGenVidResponse,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_base64_string,
)
@@ -35,28 +28,6 @@ MODELS_MAP = {
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
}
def convert_image_to_base64(image: torch.Tensor):
if image is None:
return None
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
else:
return None
if hasattr(video, "gcsUri") and video.gcsUri:
return str(video.gcsUri)
return None
class VeoVideoGenerationNode(IO.ComfyNode):
"""
@@ -169,18 +140,13 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Prepare the instances for the request
instances = []
instance = {
"prompt": prompt
}
instance = {"prompt": prompt}
# Add image if provided
if image is not None:
image_base64 = convert_image_to_base64(image)
image_base64 = tensor_to_base64_string(image)
if image_base64:
instance["image"] = {
"bytesBase64Encoded": image_base64,
"mimeType": "image/png"
}
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
instances.append(instance)
@@ -198,119 +164,77 @@ class VeoVideoGenerationNode(IO.ComfyNode):
if seed > 0:
parameters["seed"] = seed
# Only add generateAudio for Veo 3 models
if "veo-3.0" in model:
if model.find("veo-2.0") == -1:
parameters["generateAudio"] = generate_audio
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=f"/proxy/veo/{model}/generate",
method=HttpMethod.POST,
request_model=VeoGenVidRequest,
response_model=VeoGenVidResponse
),
request=VeoGenVidRequest(
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=instances,
parameters=parameters
parameters=parameters,
),
auth_kwargs=auth,
)
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function
def status_extractor(response):
# Only return "completed" if the operation is done, regardless of success or failure
# We'll check for errors after polling completes
return "completed" if response.done else "pending"
# Define progress extractor function
def progress_extractor(response):
# Could be enhanced if the API provides progress information
return None
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/veo/{model}/poll",
method=HttpMethod.POST,
request_model=VeoGenVidPollRequest,
response_model=VeoGenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=VeoGenVidPollRequest(
operationName=operation_name
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
auth_kwargs=auth,
poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
# Execute the polling operation
poll_response = await poll_operation.execute()
# Now check for errors in the final response
# Check for error in poll response
if hasattr(poll_response, 'error') and poll_response.error:
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
logging.error(error_message)
raise Exception(error_message)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
# Check for RAI filtered content
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
poll_response.response.raiMediaFilteredCount > 0):
if (
hasattr(poll_response.response, "raiMediaFilteredCount")
and poll_response.response.raiMediaFilteredCount > 0
):
# Extract reason message if available
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
poll_response.response.raiMediaFilteredReasons):
if (
hasattr(poll_response.response, "raiMediaFilteredReasons")
and poll_response.response.raiMediaFilteredReasons
):
reason = poll_response.response.raiMediaFilteredReasons[0]
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
else:
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
logging.error(error_message)
raise Exception(error_message)
# Extract video data
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
# Check if video is provided as base64 or URL
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
# Decode base64 string to bytes
video_data = base64.b64decode(video.bytesBase64Encoded)
elif hasattr(video, 'gcsUri') and video.gcsUri:
# Download from URL
async with aiohttp.ClientSession() as session:
async with session.get(video.gcsUri) as video_response:
video_data = await video_response.content.read()
else:
raise Exception("Video returned but no data or URL was provided")
else:
raise Exception("Video generation completed but no video was returned")
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if not video_data:
raise Exception("No video data was returned")
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
logging.info("Video generation completed successfully")
# Convert video data to BytesIO object
video_io = BytesIO(video_data)
# Return VideoFromFile object
return IO.NodeOutput(VideoFromFile(video_io))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@@ -393,7 +317,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
),
IO.Combo.Input(
"model",
options=list(MODELS_MAP.keys()),
options=[
"veo-3.1-generate",
"veo-3.1-fast-generate",
"veo-3.0-generate-001",
"veo-3.0-fast-generate-001",
],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,
@@ -425,5 +354,6 @@ class VeoExtension(ComfyExtension):
Veo3VideoGenerationNode,
]
async def comfy_entrypoint() -> VeoExtension:
return VeoExtension()

View File

@@ -1,27 +1,23 @@
import logging
from enum import Enum
from typing import Any, Callable, Optional, Literal, TypeVar
from typing_extensions import override
from typing import Literal, Optional, TypeVar
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.util.validation_utils import (
validate_aspect_ratio_closeness,
validate_image_dimensions,
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
)
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
@@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = 'viduq1'
vidu_q1 = "viduq1"
class AspectRatio(str, Enum):
@@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskStatus(str, Enum):
created = "created"
queueing = "queueing"
processing = "processing"
success = "success"
failed = "failed"
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: TaskStatus = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
@@ -85,32 +74,11 @@ class TaskResult(BaseModel):
class TaskStatusResponse(BaseModel):
state: TaskStatus = Field(...)
state: str = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[TaskStatus.success.value],
failed_statuses=[TaskStatus.failed.value],
status_extractor=lambda response: response.state.value,
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
@@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult:
async def execute_task(
cls: type[IO.ComfyNode],
vidu_endpoint: str,
auth_kwargs: Optional[dict[str, str]],
payload: TaskCreationRequest,
estimated_duration: int,
node_id: str,
) -> R:
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=vidu_endpoint,
method=HttpMethod.POST,
request_model=TaskCreationRequest,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if response.state == TaskStatus.failed:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
response_model=TaskCreationResponse,
data=payload,
)
if response.state == "failed":
error_msg = f"Vidu request failed. Code: {response.code}"
logging.error(error_msg)
raise RuntimeError(error_msg)
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=VIDU_GET_GENERATION_STATUS % response.task_id,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
result_url_extractor=get_video_url_from_response,
return await poll_op(
cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
estimated_duration=estimated_duration,
node_id=node_id,
)
@@ -258,11 +216,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@@ -353,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@@ -362,17 +316,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@@ -473,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
@@ -484,17 +434,13 @@ class ViduReferenceVideoNode(IO.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
cls,
images,
max_images=7,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@@ -587,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@@ -596,15 +542,11 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = [
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
]
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
ViduStartEndToVideoNode,
]
async def comfy_entrypoint() -> ViduExtension:
return ViduExtension()

View File

@@ -1,28 +1,24 @@
import re
from typing import Optional, Type, Union
from typing_extensions import override
from typing import Optional
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
R,
T,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
from typing_extensions import override
from comfy_api_nodes.apinode_utils import (
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
audio_to_base64_string,
validate_audio_duration,
)
class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
@@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel):
request_id: str = Field(...)
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
async def process_task(
auth_kwargs: dict[str, str],
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
node_id: str,
estimated_duration: int,
poll_interval: int,
) -> Type[R]:
initial_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=url,
method=HttpMethod.POST,
request_model=request_model,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=response_model,
),
completed_statuses=["SUCCEEDED"],
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
status_extractor=lambda x: x.output.task_status,
estimated_duration=estimated_duration,
poll_interval=poll_interval,
node_id=node_id,
auth_kwargs=auth_kwargs,
).execute()
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
class WanTextToImageApi(IO.ComfyNode):
@@ -259,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
tooltip='Whether to add an "AI generated" watermark to the result.',
optional=True,
),
],
@@ -286,26 +236,28 @@ class WanTextToImageApi(IO.ComfyNode):
prompt_extend: bool = True,
watermark: bool = True,
):
payload = Text2ImageTaskCreationRequest(
model=model,
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
parameters=Txt2ImageParametersField(
size=f"{width}*{height}",
seed=seed,
prompt_extend=prompt_extend,
watermark=watermark,
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Text2ImageTaskCreationRequest(
model=model,
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
parameters=Txt2ImageParametersField(
size=f"{width}*{height}",
seed=seed,
prompt_extend=prompt_extend,
watermark=watermark,
),
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
request_model=Text2ImageTaskCreationRequest,
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
status_extractor=lambda x: x.output.task_status,
estimated_duration=9,
poll_interval=3,
)
@@ -320,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
IO.Combo.Input(
"model",
@@ -376,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
tooltip='Whether to add an "AI generated" watermark to the result.',
optional=True,
),
],
@@ -408,28 +360,30 @@ class WanImageToImageApi(IO.ComfyNode):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
status_extractor=lambda x: x.output.task_status,
estimated_duration=42,
poll_interval=3,
poll_interval=4,
)
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
@@ -523,7 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
tooltip='Whether to add an "AI generated" watermark to the result.',
optional=True,
),
],
@@ -557,28 +511,31 @@ class WanTextToVideoApi(IO.ComfyNode):
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Text2VideoTaskCreationRequest(
model=model,
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
parameters=Text2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Text2VideoTaskCreationRequest(
model=model,
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
parameters=Text2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Text2VideoTaskCreationRequest,
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
status_extractor=lambda x: x.output.task_status,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)
@@ -667,7 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
tooltip='Whether to add an "AI generated" watermark to the result.',
optional=True,
),
],
@@ -699,35 +656,37 @@ class WanImageToVideoApi(IO.ComfyNode):
):
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
audio_url = None
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Image2VideoTaskCreationRequest(
model=model,
input=Image2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
),
parameters=Image2VideoParametersField(
resolution=resolution,
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Image2VideoTaskCreationRequest(
model=model,
input=Image2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
),
parameters=Image2VideoParametersField(
resolution=resolution,
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Image2VideoTaskCreationRequest,
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
status_extractor=lambda x: x.output.task_status,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)

View File

@@ -0,0 +1,97 @@
from ._helpers import get_fs_object_size
from .client import (
ApiEndpoint,
poll_op,
poll_op_raw,
sync_op,
sync_op_raw,
)
from .conversions import (
audio_bytes_to_audio_input,
audio_input_to_mp3,
audio_to_base64_string,
bytesio_to_image_tensor,
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
video_to_base64_string,
)
from .download_helpers import (
download_url_as_bytesio,
download_url_to_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
)
from .upload_helpers import (
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from .validation_utils import (
get_number_of_images,
validate_aspect_ratio_string,
validate_audio_duration,
validate_container_format_is_mp4,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string,
validate_video_dimensions,
validate_video_duration,
)
__all__ = [
# API client
"ApiEndpoint",
"poll_op",
"poll_op_raw",
"sync_op",
"sync_op_raw",
# Upload helpers
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_images_to_comfyapi",
"upload_video_to_comfyapi",
# Download helpers
"download_url_as_bytesio",
"download_url_to_bytesio",
"download_url_to_image_tensor",
"download_url_to_video_output",
# Conversions
"audio_bytes_to_audio_input",
"audio_input_to_mp3",
"audio_to_base64_string",
"bytesio_to_image_tensor",
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"video_to_base64_string",
# Validation utilities
"get_number_of_images",
"validate_aspect_ratio_string",
"validate_audio_duration",
"validate_container_format_is_mp4",
"validate_image_aspect_ratio",
"validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
# Misc functions
"get_fs_object_size",
]

View File

@@ -0,0 +1,71 @@
import asyncio
import contextlib
import os
import time
from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted
def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption."""
return processing_interrupted()
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
return node_cls.hidden.unique_id
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
if node_cls.hidden.auth_token_comfy_org:
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
if node_cls.hidden.api_key_comfy_org:
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
return {}
def default_base_url() -> str:
return getattr(args, "comfy_api_base", "https://api.comfy.org")
async def sleep_with_interrupt(
seconds: float,
node_cls: Optional[type[IO.ComfyNode]],
label: Optional[str] = None,
start_ts: Optional[float] = None,
estimated_total: Optional[int] = None,
*,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
):
"""
Sleep in 1s slices while:
- Checking for interruption (raises ProcessingInterrupted).
- Optionally emitting time progress via display_callback (if provided).
"""
end = time.monotonic() + seconds
while True:
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
now = time.monotonic()
if start_ts is not None and label and display_callback:
with contextlib.suppress(Exception):
display_callback(node_cls, label, int(now - start_ts), estimated_total)
if now >= end:
break
await asyncio.sleep(min(1.0, end - now))
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())

View File

@@ -0,0 +1,936 @@
import asyncio
import contextlib
import json
import logging
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse
import aiohttp
from aiohttp.client_exceptions import ClientError, ContentTypeError
from pydantic import BaseModel
from comfy import utils
from comfy_api.latest import IO
from comfy_api_nodes.apis import request_logger
from server import PromptServer
from ._helpers import (
default_base_url,
get_auth_header,
get_node_id,
is_processing_interrupted,
sleep_with_interrupt,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
M = TypeVar("M", bound=BaseModel)
class ApiEndpoint:
def __init__(
self,
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
query_params: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
):
self.path = path
self.method = method
self.query_params = query_params or {}
self.headers = headers or {}
@dataclass
class _RequestConfig:
node_cls: type[IO.ComfyNode]
endpoint: ApiEndpoint
timeout: float
content_type: str
data: Optional[dict[str, Any]]
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
multipart_parser: Optional[Callable]
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
@dataclass
class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
price: Optional[float] = None
estimated_duration: Optional[int] = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
response_model: Type[M],
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
cls,
endpoint,
data=data,
files=files,
content_type=content_type,
timeout=timeout,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
estimated_duration=estimated_duration,
as_binary=False,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
response_model: Type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[BaseModel] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
cls,
poll_endpoint=poll_endpoint,
status_extractor=_wrap_model_extractor(response_model, status_extractor),
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
price_extractor=_wrap_model_extractor(response_model, price_extractor),
completed_statuses=completed_statuses,
failed_statuses=failed_statuses,
queued_statuses=queued_statuses,
data=data,
poll_interval=poll_interval,
max_poll_attempts=max_poll_attempts,
timeout_per_poll=timeout_per_poll,
max_retries_per_poll=max_retries_per_poll,
retry_delay_per_poll=retry_delay_per_poll,
retry_backoff_per_poll=retry_backoff_per_poll,
estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
for k, v in list(data.items()):
if isinstance(v, Enum):
data[k] = v.value
cfg = _RequestConfig(
node_cls=cls,
endpoint=endpoint,
timeout=timeout,
content_type=content_type,
data=data,
files=files,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
monitor_progress=monitor_progress,
estimated_total=estimated_duration,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
)
return await _request_base(cfg, expect_binary=as_binary)
async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
Uses default complete, failed and queued states assumption.
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
async def _ticker():
"""Emit a UI update every second while polling is in progress."""
try:
while not stop_ticker.is_set():
if is_processing_interrupted():
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
)
_display_time_progress(
cls,
status=state.status_label,
elapsed_seconds=int(now - state.started),
estimated_total=state.estimated_duration,
price=state.price,
is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed),
)
await asyncio.sleep(1.0)
except Exception as exc:
logging.debug("Polling ticker exited: %s", exc)
ticker_task = asyncio.create_task(_ticker())
try:
while consumed_attempts < max_poll_attempts:
try:
resp_json = await sync_op_raw(
cls,
poll_endpoint,
data=data,
timeout=timeout_per_poll,
max_retries=max_retries_per_poll,
retry_delay=retry_delay_per_poll,
retry_backoff=retry_backoff_per_poll,
wait_label="Checking",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
if not isinstance(resp_json, dict):
raise Exception("Polling endpoint returned non-JSON response.")
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
try:
status = _normalize_status_value(status_extractor(resp_json))
except Exception as e:
logging.error("Status extraction failed: %s", e)
status = None
if price_extractor:
new_price = price_extractor(resp_json)
if new_price is not None:
state.price = new_price
if progress_extractor:
new_progress = progress_extractor(resp_json)
if new_progress is not None and last_progress != new_progress:
progress_bar.update_absolute(new_progress, total=100)
last_progress = new_progress
now_ts = time.monotonic()
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
state.status_label = status or ("Queued" if is_queued else "Processing")
if status in completed_states:
if state.active_since is not None:
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
if progress_bar and last_progress != 100:
progress_bar.update_absolute(100, total=100)
_display_time_progress(
cls,
status=status if status else "Completed",
elapsed_seconds=int(now_ts - started),
estimated_total=estimated_duration,
price=state.price,
is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed),
)
return resp_json
if status in failed_states:
msg = f"Task failed: {json.dumps(resp_json)}"
logging.error(msg)
raise Exception(msg)
try:
await sleep_with_interrupt(poll_interval, cls, None, None, None)
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
if not is_queued:
consumed_attempts += 1
raise Exception(
f"Polling timed out after {max_poll_attempts} non-queued attempts "
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
)
except ProcessingInterrupted:
raise
except (LocalNetworkError, ApiServerError):
raise
except Exception as e:
raise Exception(f"Polling aborted due to error: {e}") from e
finally:
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
def _display_text(
node_cls: type[IO.ComfyNode],
text: Optional[str],
*,
status: Optional[Union[str, int]] = None,
price: Optional[float] = None,
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None:
display_lines.append(f"Price: ${float(price):,.4f}")
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
def _display_time_progress(
node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]],
elapsed_seconds: int,
estimated_total: Optional[int] = None,
*,
price: Optional[float] = None,
is_queued: Optional[bool] = None,
processing_elapsed_seconds: Optional[int] = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]:
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
results = {
"internet_accessible": False,
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
if not results["internet_accessible"]:
return results
parsed = urlparse(default_base_url())
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
with contextlib.suppress(ClientError, OSError):
async with session.get(health_url) as resp:
results["api_accessible"] = resp.status < 500
return results
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
"""Normalize (filename, value, content_type)."""
if len(t) == 2:
return t[0], t[1], "application/octet-stream"
if len(t) == 3:
return t[0], t[1], t[2]
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
if v is not None:
params[k] = v
return params
def _friendly_http_message(status: int, body: Any) -> str:
if status == 401:
return "Unauthorized: Please login first to use this node."
if status == 402:
return "Payment Required: Please add credits to your account to use this node."
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: Please try again later."
try:
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict):
msg = err.get("message")
typ = err.get("type")
if msg and typ:
return f"API Error: {msg} (Type: {typ})"
if msg:
return f"API Error: {msg}"
return f"API Error: {json.dumps(body)}"
else:
txt = str(body)
if len(txt) <= 200:
return f"API Error (raw): {txt}"
return f"API Error (status {status})"
except Exception:
return f"HTTP {status}: Unknown error"
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
slug = path.strip("/").replace("/", "_") or "op"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
data: Optional[dict[str, Any]],
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
) -> Optional[Union[dict[str, Any], str]]:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
file_fields: list[dict[str, str]] = []
if files:
file_iter = files if isinstance(files, list) else list(files.items())
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
url = cfg.endpoint.path
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None
while True:
attempt += 1
stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None
sess: Optional[aiohttp.ClientSession] = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers)
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
sess = aiohttp.ClientSession(timeout=timeout)
if cfg.content_type == "multipart/form-data" and method != "GET":
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
payload_headers.pop("Content-Type", None)
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename, file_value, content_type = _unpack_tuple(file_obj)
else:
filename = getattr(file_obj, "name", field_name)
file_value = file_obj
content_type = "application/octet-stream"
# Attempt to rewind BytesIO for retries
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
# Race: request vs. monitor (interruption)
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
if req_task in pending:
req_task.cancel()
raise ProcessingInterrupted("Task cancelled")
# Otherwise, request finished
resp = await req_task
async with resp:
if resp.status >= 400:
try:
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
method,
url,
resp.status,
delay,
attempt,
cfg.max_retries,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
raise Exception(msg)
if expect_binary:
buff = bytearray()
last_tick = time.monotonic()
async for chunk in resp.content.iter_chunked(64 * 1024):
buff.extend(chunk)
now = time.monotonic()
if now - last_tick >= 1.0:
last_tick = now
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return bytes_payload
else:
try:
payload = await resp.json()
response_content_to_log: Any = payload
except (ContentTypeError, json.JSONDecodeError):
text = await resp.text()
try:
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
if attempt <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt,
cfg.max_retries,
str(e),
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
) from e
finally:
stop_event.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,
elapsed_seconds=(
final_elapsed_seconds
if final_elapsed_seconds is not None
else int(time.monotonic() - start_time)
),
estimated_total=cfg.estimated_total,
price=None,
is_queued=False,
processing_elapsed_seconds=final_elapsed_seconds,
)
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
logging.error(
"Response validation failed for %s: %s",
getattr(response_model, "__name__", response_model),
e,
)
raise Exception(
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
) from e
def _wrap_model_extractor(
response_model: Type[M],
extractor: Optional[Callable[[M], Any]],
) -> Optional[Callable[[dict[str, Any]], Any]]:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
the same response for multiple extractors in a single poll attempt.
"""
if extractor is None:
return None
_cache: dict[int, M] = {}
def _wrapped(d: dict[str, Any]) -> Any:
try:
key = id(d)
model = _cache.get(key)
if model is None:
model = response_model.model_validate(d)
_cache[key] = model
return extractor(model)
except Exception as e:
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
raise
return _wrapped
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
if not values:
return set()
out: set[Union[str, int]] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
out.add(nv)
return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
if isinstance(val, str):
return val.strip().lower()
return val

View File

@@ -0,0 +1,14 @@
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
class ProcessingInterrupted(Exception):
"""Operation was interrupted by user/runtime via processing_interrupted()."""

View File

@@ -0,0 +1,470 @@
import base64
import logging
import math
import mimetypes
import uuid
from io import BytesIO
from typing import Optional
import av
import numpy as np
import torch
from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
return img_binary
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def video_to_base64_string(
video: Input.Video,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2**15)
elif wav.dtype == torch.int32:
return wav.float() / (2**31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@@ -0,0 +1,262 @@
import asyncio
import contextlib
import uuid
from io import BytesIO
from pathlib import Path
from typing import IO, Optional, Union
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api_nodes.apis import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
is_processing_interrupted,
sleep_with_interrupt,
)
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
*,
timeout: Optional[float] = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
cls: type[COMFY_IO.ComfyNode] = None,
) -> None:
"""Stream-download a URL to `dest`.
`dest` must be one of:
- a BytesIO (rewound to 0 after write),
- a file-like object opened in binary write mode (must implement .write()),
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
to an absolute URL and authentication headers can be applied.
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
attempt = 0
delay = retry_delay
headers: dict[str, str] = {}
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
while True:
attempt += 1
op_id = _generate_operation_id("GET", url, attempt)
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
session: Optional[aiohttp.ClientSession] = None
stop_evt: Optional[asyncio.Event] = None
monitor_task: Optional[asyncio.Task] = None
req_task: Optional[asyncio.Task] = None
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
async def _monitor():
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
raise ProcessingInterrupted("Task cancelled")
try:
resp = await req_task
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
async with resp:
if resp.status >= 400:
with contextlib.suppress(Exception):
try:
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=f"HTTP {resp.status}",
)
if resp.status in _RETRY_STATUS and attempt <= max_retries:
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
raise Exception(f"Failed to download (HTTP {resp.status}).")
if is_path_sink:
p = Path(str(dest))
with contextlib.suppress(Exception):
p.parent.mkdir(parents=True, exist_ok=True)
fhandle = open(p, "wb")
sink = fhandle
else:
sink = dest # BytesIO or file-like
written = 0
while True:
try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if not chunk:
if resp.content.at_eof():
break
continue
sink.write(chunk)
written += len(chunk)
if isinstance(dest, BytesIO):
with contextlib.suppress(Exception):
dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
finally:
if stop_evt is not None:
stop_evt.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if req_task and not req_task.done():
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
if session:
with contextlib.suppress(Exception):
await session.close()
if fhandle:
with contextlib.suppress(Exception):
fhandle.flush()
fhandle.close()
async def download_url_to_image_tensor(
url: str,
*,
timeout: float = None,
cls: type[COMFY_IO.ComfyNode] = None,
) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
result = BytesIO()
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
return bytesio_to_image_tensor(result)
async def download_url_to_video_output(
video_url: str,
*,
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result)
async def download_url_as_bytesio(
url: str,
*,
timeout: float = None,
cls: type[COMFY_IO.ComfyNode] = None,
) -> BytesIO:
"""Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
result = BytesIO()
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
return result
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@@ -0,0 +1,338 @@
import asyncio
import contextlib
import logging
import time
import uuid
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
ApiEndpoint,
_diagnose_connectivity,
_display_time_progress,
sync_op,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import (
audio_ndarray_to_bytesio,
audio_tensor_to_contiguous_ndarray,
tensor_to_bytesio,
)
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
async def upload_images_to_comfyapi(
cls: type[IO.ComfyNode],
image: torch.Tensor,
*,
max_images: int = 8,
mime_type: Optional[str] = None,
wait_label: Optional[str] = "Uploading",
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
download_urls.append(url)
return download_urls
async def upload_audio_to_comfyapi(
cls: type[IO.ComfyNode],
audio: Input.Audio,
*,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
"""
if max_duration is not None:
try:
actual_duration = video.get_duration()
if actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
wait_label: Optional[str] = "Uploading",
) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
create_resp = await sync_op(
cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
data=request_object,
response_model=UploadResponse,
final_label_on_success=None,
monitor_progress=False,
)
await upload_file(
cls,
create_resp.upload_url,
file_bytes_io,
content_type=upload_mime_type,
wait_label=wait_label,
)
return create_resp.download_url
async def upload_file(
cls: type[IO.ComfyNode],
upload_url: str,
file: Union[BytesIO, str],
*,
content_type: Optional[str] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: Optional[str] = None,
) -> None:
"""
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
Args:
cls: Node class (provides auth context + UI progress hooks).
upload_url: Pre-signed PUT URL.
file: BytesIO or path string.
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
max_retries: Maximum retry attempts.
retry_delay: Initial delay in seconds.
retry_backoff: Exponential backoff factor.
wait_label: Progress label shown in Comfy UI.
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
"""
if isinstance(file, BytesIO):
with contextlib.suppress(Exception):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("file must be a BytesIO or a filesystem path string")
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
attempt = 0
delay = retry_delay
start_ts = time.monotonic()
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
timeout = aiohttp.ClientTimeout(total=None)
stop_evt = asyncio.Event()
async def _monitor():
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
if wait_label:
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
monitor_task = asyncio.create_task(_monitor())
sess: Optional[aiohttp.ClientSession] = None
try:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
req_task = asyncio.create_task(req)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
raise ProcessingInterrupted("Upload cancelled")
try:
resp = await req_task
except asyncio.CancelledError:
raise ProcessingInterrupted("Upload cancelled") from None
async with resp:
if resp.status >= 400:
with contextlib.suppress(Exception):
try:
body = await resp.json()
except Exception:
body = await resp.text()
msg = f"Upload failed with status {resp.status}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
)
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
)
delay *= retry_backoff
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The API service appears unreachable at this time.") from e
finally:
stop_evt.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if sess:
with contextlib.suppress(Exception):
await sess.close()
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
except Exception:
slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@@ -2,6 +2,8 @@ import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@@ -28,76 +30,69 @@ def validate_image_dimensions(
if max_width is not None and width > max_width:
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Image height must be at least {min_height}px, got {height}px"
)
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
if max_height is not None and height > max_height:
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
def validate_image_aspect_ratio(
image: torch.Tensor,
min_aspect_ratio: Optional[float] = None,
max_aspect_ratio: Optional[float] = None,
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
)
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
)
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
a1, b1 = min_ratio
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
w, h = get_image_dimensions(image)
if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_aspect_ratio_closeness(
start_img,
end_img,
min_rel: float,
max_rel: float,
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
min_rel: float, # e.g. 0.8
max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""
Validates that the two images' aspect ratios are 'close'.
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
Returns the computed closeness factor C.
"""
w1, h1 = get_image_dimensions(first_image)
w2, h2 = get_image_dimensions(second_image)
if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions")
ar1 = w1 / h1
ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
limit = max(max_rel, 1.0 / min_rel)
if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")
raise ValueError(
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
f"allowed range {min_rel}{max_rel} (limit {limit:.2g})."
)
return closeness
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
ar = _parse_aspect_ratio_string(aspect_ratio)
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_video_dimensions(
@@ -118,9 +113,7 @@ def validate_video_dimensions(
if max_width is not None and width > max_width:
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Video height must be at least {min_height}px, got {height}px"
)
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
if max_height is not None and height > max_height:
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
@@ -138,13 +131,9 @@ def validate_video_duration(
epsilon = 0.0001
if min_duration is not None and min_duration - epsilon > duration:
raise ValueError(
f"Video duration must be at least {min_duration}s, got {duration}s"
)
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
if max_duration is not None and duration > max_duration + epsilon:
raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s"
)
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def get_number_of_images(images):
@@ -165,3 +154,77 @@ def validate_audio_duration(
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
if max_duration is not None and dur - eps > max_duration:
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _ratio_from_tuple(r: tuple[float, float]) -> float:
a, b = r
if a <= 0 or b <= 0:
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
return a / b
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
if lo is not None and hi is not None and lo > hi:
lo, hi = hi, lo # normalize order if caller swapped them
if lo is not None:
if (ar <= lo) if strict else (ar < lo):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
if hi is not None:
if (ar >= hi) if strict else (ar > hi):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
def _parse_aspect_ratio_string(ar_str: str) -> float:
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
parts = ar_str.split(":")
if len(parts) != 2:
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except ValueError as exc:
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
if a <= 0 or b <= 0:
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
return a / b

View File

@@ -1,4 +1,9 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -48,7 +53,7 @@ class Unhashable:
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
@@ -188,6 +193,9 @@ class BasicCache:
self._clean_cache()
self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -265,6 +273,29 @@ class HierarchicalCache(BasicCache):
assert cache is not None
return await cache._ensure_subcache(node_id, children_ids)
class NullCache:
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
pass
def all_node_ids(self):
return []
def clean_unused(self):
pass
def poll(self, **kwargs):
pass
def get(self, node_id):
return None
def set(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
@@ -318,155 +349,75 @@ class LRUCache(BasicCache):
return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
async def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
self._clean_subcaches()
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()

View File

@@ -153,8 +153,9 @@ class TopologicalSort:
continue
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
if (include_lazy or not is_lazy):
if not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
@@ -194,10 +195,40 @@ class ExecutionList(TopologicalSort):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
self.execution_cache = {}
self.execution_cache_listeners = {}
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if not from_node_id in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
for to_node_id in self.execution_cache_listeners[node_id]:
if to_node_id in self.execution_cache:
self.execution_cache[to_node_id][node_id] = value
def add_strong_link(self, from_node_id, from_socket, to_node_id):
super().add_strong_link(from_node_id, from_socket, to_node_id)
self.cache_link(from_node_id, to_node_id)
async def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
@@ -277,6 +308,8 @@ class ExecutionList(TopologicalSort):
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.execution_cache.pop(node_id, None)
self.execution_cache_listeners.pop(node_id, None)
self.staged_node_id = None
def get_nodes_in_cycle(self):

View File

@@ -1,20 +1,26 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SetUnionControlNetType:
class SetUnionControlNetType(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
}}
def define_schema(cls):
return io.Schema(
node_id="SetUnionControlNetType",
category="conditioning/controlnet",
inputs=[
io.ControlNet.Input("control_net"),
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
],
outputs=[
io.ControlNet.Output(),
],
)
CATEGORY = "conditioning/controlnet"
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "set_controlnet_type"
def set_controlnet_type(self, control_net, type):
@classmethod
def execute(cls, control_net, type) -> io.NodeOutput:
control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
if type_number >= 0:
@@ -22,27 +28,36 @@ class SetUnionControlNetType:
else:
control_net.set_extra_arg("control_type", [])
return (control_net,)
return io.NodeOutput(control_net)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
set_controlnet_type = execute # TODO: remove
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
def define_schema(cls):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.ControlNet.Input("control_net"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Mask.Input("mask"),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
@classmethod
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
@@ -50,11 +65,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
return io.NodeOutput(result[0], result[1])
apply_inpaint_controlnet = execute # TODO: remove
class ControlNetExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SetUnionControlNetType,
ControlNetInpaintingAliMamaApply,
]
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}
async def comfy_entrypoint() -> ControlNetExtension:
return ControlNetExtension()

View File

@@ -244,6 +244,8 @@ class EasyCacheHolder:
self.total_steps_skipped += 1
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
@@ -261,9 +263,8 @@ class EasyCacheHolder:
slicing.append(slice(None, dim_u))
else:
slicing.append(slice(None))
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
x = x[slicing]
x += self.uuid_cache_diffs[uuid].to(x.device)
batch_slice = batch_slice + slicing
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):

View File

@@ -2,6 +2,9 @@ import comfy.utils
import folder_paths
import torch
import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
@@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
return hypernetwork_patch(out, strength)
class HypernetworkLoader:
class HypernetworkLoader(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
category="loaders",
inputs=[
IO.Model.Input("model"),
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
@classmethod
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return (model_hypernetwork,)
return IO.NodeOutput(model_hypernetwork)
NODE_CLASS_MAPPINGS = {
"HypernetworkLoader": HypernetworkLoader
}
load_hypernetwork = execute # TODO: remove
class HyperNetworkExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
HypernetworkLoader,
]
async def comfy_entrypoint() -> HyperNetworkExtension:
return HyperNetworkExtension()

View File

@@ -1,86 +0,0 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
class MultiGPUWorkUnitsNode:
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
NodeId = "MultiGPU_WorkUnits"
NodeName = "MultiGPU Work Units"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "init_multigpu"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
return (model,)
class MultiGPUOptionsNode:
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
"""
NodeId = "MultiGPU_Options"
NodeName = "MultiGPU Options"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("GPU_OPTIONS",)
FUNCTION = "create_gpu_options"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return (gpu_options,)
node_list = [
MultiGPUWorkUnitsNode,
MultiGPUOptionsNode
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.65"
__version__ = "0.3.68"

View File

@@ -1,6 +1,6 @@
import os
import importlib.util
from comfy.cli_args import args
from comfy.cli_args import args, PerformanceFeature
import subprocess
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
@@ -75,8 +75,9 @@ if not args.cuda_malloc:
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
args.cuda_malloc = cuda_malloc_supported()
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
except:
pass

View File

@@ -18,9 +18,10 @@ from comfy_execution.caching import (
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -88,54 +89,62 @@ class IsChangedCache:
return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
NONE = 2
RAM_PRESSURE = 3
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
self.objects = NullCache()
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
@@ -153,17 +162,17 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if outputs is None:
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None:
mark_missing()
continue
if output_index >= len(cached_output):
if output_index >= len(cached.outputs):
mark_missing()
continue
obj = cached_output[output_index]
obj = cached.outputs[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@@ -392,7 +401,7 @@ def format_value(x):
else:
return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -400,11 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@@ -434,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_cached.outputs[source_output]:
resolved_output.append(o)
else:
@@ -443,10 +456,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
del pending_subgraph_results[unique_id]
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -504,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
@@ -512,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
@@ -525,10 +539,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
@@ -549,11 +559,16 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -597,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True
@@ -679,6 +694,7 @@ class PromptExecutor:
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -692,7 +708,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -701,18 +717,16 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
@@ -1110,7 +1124,7 @@ class PromptQueue:
messages: List[str]
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
@@ -1120,10 +1134,8 @@ class PromptQueue:
if status is not None:
status_dict = copy.deepcopy(status._asdict())
# Remove sensitive data from extra_data before storing in history
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
if process_item is not None:
prompt = process_item(prompt)
self.history[prompt[1]] = {
"prompt": prompt,

17
main.py
View File

@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -192,14 +194,21 @@ def prompt_worker(q, server_instance):
prompt_id = item[1]
server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
sensitive = item[5]
extra_data = item[3].copy()
for k in sensitive:
extra_data[k] = sensitive[k]
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

View File

@@ -2304,7 +2304,6 @@ async def init_builtin_extra_nodes():
"nodes_mahiro.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",
@@ -2330,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
]
import_failed = []
@@ -2350,6 +2350,7 @@ async def init_builtin_api_nodes():
"nodes_kling.py",
"nodes_bfl.py",
"nodes_bytedance.py",
"nodes_ltxv.py",
"nodes_luma.py",
"nodes_recraft.py",
"nodes_pixverse.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.65"
version = "0.3.68"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -50,6 +50,8 @@ messages_control.disable = [
"too-many-branches",
"too-many-locals",
"too-many-arguments",
"too-many-return-statements",
"too-many-nested-blocks",
"duplicate-code",
"abstract-method",
"superfluous-parens",

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.28.6
comfyui-workflow-templates==0.1.95
comfyui-embedded-docs==0.3.0
comfyui-frontend-package==1.28.8
comfyui-workflow-templates==0.2.11
comfyui-embedded-docs==0.3.1
torch
torchsde
torchvision

View File

@@ -35,6 +35,7 @@ from comfy_api.internal import _ComfyNodeInternal
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -48,6 +49,28 @@ async def send_socket_catch_exception(function, message):
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
# Track deprecated paths that have been warned about to only warn once per file
_deprecated_paths_warned = set()
@web.middleware
async def deprecation_warning(request: web.Request, handler):
"""Middleware to warn about deprecated frontend API paths"""
path = request.path
if path.startswith("/scripts/ui") or path.startswith("/extensions/core/"):
# Only warn once per unique file path
if path not in _deprecated_paths_warned:
_deprecated_paths_warned.add(path)
logging.warning(
f"[DEPRECATION WARNING] Detected import of deprecated legacy API: {path}. "
f"This is likely caused by a custom node extension using outdated APIs. "
f"Please update your extensions or contact the extension author for an updated version."
)
response: web.Response = await handler(request)
return response
@web.middleware
async def compress_body(request: web.Request, handler):
accept_encoding = request.headers.get("Accept-Encoding", "")
@@ -151,6 +174,7 @@ class PromptServer():
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@@ -159,7 +183,7 @@ class PromptServer():
self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0
middlewares = [cache_control]
middlewares = [cache_control, deprecation_warning]
if args.enable_compress_response_body:
middlewares.append(compress_body)
@@ -667,8 +691,9 @@ class PromptServer():
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
remove_sensitive = lambda queue: [x[:5] for x in queue]
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")
@@ -704,7 +729,11 @@ class PromptServer():
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
sensitive = {}
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
@@ -797,6 +826,7 @@ class PromptServer():
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.

View File

@@ -0,0 +1,232 @@
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy import ops
from comfy.quant_ops import QuantizedTensor
class SimpleModel(torch.nn.Module):
def __init__(self, operations=ops.disable_weight_init):
super().__init__()
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)
def forward(self, x):
x = self.layer1(x)
x = torch.nn.functional.relu(x)
x = self.layer2(x)
x = torch.nn.functional.relu(x)
x = self.layer3(x)
return x
class TestMixedPrecisionOps(unittest.TestCase):
def test_all_layers_standard(self):
"""Test that model with no quantization works normally"""
# Configure no quantization
ops.MixedPrecisionOps._layer_quant_config = {}
# Create model
model = SimpleModel(operations=ops.MixedPrecisionOps)
# Initialize weights manually
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))
# Initialize weight_function and bias_function
for layer in [model.layer1, model.layer2, model.layer3]:
layer.weight_function = []
layer.bias_function = []
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
self.assertEqual(output.dtype, torch.bfloat16)
def test_mixed_precision_load(self):
"""Test loading a mixed precision model from state dict"""
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
},
"layer3": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict with mixed precision
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
# Layer 1: FP8 E4M3FN
"layer1.weight": fp8_weight1,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
# Layer 2: Standard BF16
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
# Layer 3: FP8 E4M3FN
"layer3.weight": fp8_weight3,
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
}
# Create model and load state dict (strict=False because custom loading pops keys)
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
def test_state_dict_quantized_preserved(self):
"""Test that quantized weights are preserved in state_dict()"""
# Configure mixed precision
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict1 = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict1, strict=False)
# Save state dict
state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)
def test_weight_function_compatibility(self):
"""Test that weight_function (LoRA) works with quantized layers"""
# Configure FP8 quantization
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA)
# This should trigger dequantization during forward pass
def apply_lora(weight):
lora_delta = torch.randn_like(weight) * 0.01
return weight + lora_delta
model.layer1.weight_function.append(apply_lora)
# Forward pass should work with LoRA (triggers weight_function path)
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
def test_error_handling_unknown_format(self):
"""Test that unknown formats raise error"""
# Configure with unknown format
layer_quant_config = {
"layer1": {
"format": "unknown_format_xyz",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict
state_dict = {
"layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
model = SimpleModel(operations=ops.MixedPrecisionOps)
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,190 @@
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class TestQuantizedTensor(unittest.TestCase):
"""Test the QuantizedTensor subclass with FP8 layout"""
def test_creation(self):
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
def test_dequantize(self):
"""Test explicit dequantization"""
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
dequantized = qt.dequantize()
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_from_float(self):
"""Test creating QuantizedTensor from float tensor"""
float_tensor = torch.randn(64, 32, dtype=torch.float32)
scale = torch.tensor(1.5)
qt = QuantizedTensor.from_float(
float_tensor,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt.shape, (64, 32))
# Verify dequantization gives approximately original values
dequantized = qt.dequantize()
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
self.assertLess(mean_rel_error, 0.1)
class TestGenericUtilities(unittest.TestCase):
"""Test generic utility operations"""
def test_detach(self):
"""Test detach operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Detach should return a new QuantizedTensor
qt_detached = qt.detach()
self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
def test_clone(self):
"""Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Clone should return a new QuantizedTensor
qt_cloned = qt.clone()
self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
# Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata)
@unittest.skipUnless(has_gpu(), "GPU not available")
def test_to_device(self):
"""Test device transfer"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu')
self.assertIsInstance(qt_cpu, QuantizedTensor)
self.assertEqual(qt_cpu.device.type, 'cpu')
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
class TestTensorCoreFP8Layout(unittest.TestCase):
"""Test the TensorCoreFP8Layout implementation"""
def test_quantize(self):
"""Test quantization method"""
float_tensor = torch.randn(32, 64, dtype=torch.float32)
scale = torch.tensor(1.5)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
dtype=torch.float8_e4m3fn
)
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
self.assertEqual(qdata.shape, float_tensor.shape)
self.assertIn('scale', layout_params)
self.assertIn('orig_dtype', layout_params)
self.assertEqual(layout_params['orig_dtype'], torch.float32)
def test_dequantize(self):
"""Test dequantization method"""
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
scale = torch.tensor(1.0)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
dtype=torch.float8_e4m3fn
)
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
# Should approximately match original
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
def test_unsupported_op_dequantizes(self):
"""Test that unsupported operations fall back to dequantization"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized tensor
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_q = QuantizedTensor.from_float(
a_fp32,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)
# Call an operation that doesn't have a registered handler
# For example, torch.abs
result = torch.abs(a_q)
# Should work via fallback (dequantize → abs → return)
self.assertNotIsInstance(result, QuantizedTensor)
expected = torch.abs(a_fp32)
# FP8 introduces quantization error, so use loose tolerance
mean_error = (result - expected).abs().mean()
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
if __name__ == "__main__":
unittest.main()

View File

@@ -152,12 +152,12 @@ class TestExecution:
# Initialize server and client
#
@fixture(scope="class", autouse=True, params=[
# (use_lru, lru_size)
(False, 0),
(True, 0),
(True, 100),
{ "extra_args" : [], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
])
def _server(self, args_pytest, request):
def server(self, args_pytest, request):
# Start server
pargs = [
'python','main.py',
@@ -167,12 +167,10 @@ class TestExecution:
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
pargs += [ str(param) for param in request.param["extra_args"] ]
print("Running server with args:", pargs) # noqa: T201
p = subprocess.Popen(pargs)
yield
yield request.param
p.kill()
torch.cuda.empty_cache()
@@ -193,7 +191,7 @@ class TestExecution:
return comfy_client
@fixture(scope="class", autouse=True)
def shared_client(self, args_pytest, _server):
def shared_client(self, args_pytest, server):
client = self.start_client(args_pytest["listen"], args_pytest["port"])
yield client
del client
@@ -225,7 +223,7 @@ class TestExecution:
assert result.did_run(mask)
assert result.did_run(lazy_mix)
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -237,9 +235,12 @@ class TestExecution:
client.run(g)
result2 = client.run(g)
for node_id, node in g.nodes.items():
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
else:
assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -251,8 +252,12 @@ class TestExecution:
client.run(g)
mask.inputs['value'] = 0.4
result2 = client.run(g)
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
else:
assert result2.did_run(input1), "Input1 should have been rerun"
assert result2.did_run(input2), "Input2 should have been rerun"
def test_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
@@ -411,7 +416,7 @@ class TestExecution:
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
# Creating the nodes in this specific order previously caused a bug
save = g.node("SaveImage")
@@ -427,7 +432,10 @@ class TestExecution:
result3 = client.run(g)
result4 = client.run(g)
assert result1.did_run(is_changed), "is_changed should have been run"
assert not result2.did_run(is_changed), "is_changed should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(is_changed), "is_changed should have been cached"
else:
assert result2.did_run(is_changed), "is_changed should have been re-run"
assert result3.did_run(is_changed), "is_changed should have been re-run"
assert result4.did_run(is_changed), "is_changed should not have been cached"
@@ -514,7 +522,7 @@ class TestExecution:
assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED`
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@@ -530,7 +538,11 @@ class TestExecution:
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
if server["should_cache_results"]:
assert not result.did_run(test_node), "The execution should have been cached"
else:
assert result.did_run(test_node), "The execution should have been re-run"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized