Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-15 04:30:02 +00:00)

Compare commits: 110 commits, v0.3.67...worksplit-
| Author | SHA1 | Date |
|---|---|---|
|  | 4661d1db5a |  |
|  | b326a544d5 |  |
|  | d89dd5f0b0 |  |
|  | 8cbbf0be6c |  |
|  | c2115a4bac |  |
|  | bb44c2ecb9 |  |
|  | efcd8280d6 |  |
|  | 9e9c129cd0 |  |
|  | ac14ee68c0 |  |
|  | 2c8f485434 |  |
|  | 383f9b34cb |  |
|  | b0741c7e5b |  |
|  | 1489399cb5 |  |
|  | 3677943fa5 |  |
|  | cfb63bfcd7 |  |
|  | 962c3c832c |  |
|  | 6ea69369ce |  |
|  | b4f559b34d |  |
|  | df122a7dba |  |
|  | 67e906aa64 |  |
|  | 382f84a826 |  |
|  | 9cca36fa2b |  |
|  | 5d5024296d |  |
|  | 3b90a30178 |  |
|  | 3c4104652b |  |
|  | 9855baaab3 |  |
|  | d53479a197 |  |
|  | 443a795850 |  |
|  | 431dec8e53 |  |
|  | 44e053c26d |  |
|  | 1ae98932f1 |  |
|  | 0336b0ace8 |  |
|  | 8ae25235ec |  |
|  | 9726eac475 |  |
|  | 272e8d42c1 |  |
|  | 6211d2be5a |  |
|  | 8be711715c |  |
|  | b5cccf1325 |  |
|  | 2a54a904f4 |  |
|  | ed6f92c975 |  |
|  | adc66c0698 |  |
|  | ccd5c01e5a |  |
|  | 2fa9affcc1 |  |
|  | 407a5a656f |  |
|  | 9ce9ff8ef8 |  |
|  | 63567c0ce8 |  |
|  | a786ce5ead |  |
|  | 4879b47648 |  |
|  | 5ccec33c22 |  |
|  | 219d3cd0d0 |  |
|  | c4ba399475 |  |
|  | cc928a786d |  |
|  | 6e144b98c4 |  |
|  | 6dca17bd2d |  |
|  | 5080105c23 |  |
|  | 093914a247 |  |
|  | 605893d3cf |  |
|  | 048f4f0b3a |  |
|  | d2504fb701 |  |
|  | b03763bca6 |  |
|  | 476aa79b64 |  |
|  | 441cfd1a7a |  |
|  | 99a5c1068a |  |
|  | 02747cde7d |  |
|  | 0b3233b4e2 |  |
|  | eda866bf51 |  |
|  | e3298b84de |  |
|  | c7feef9060 |  |
|  | 51af7fa1b4 |  |
|  | 46969c380a |  |
|  | 5db4277449 |  |
|  | 02a4d0ad7d |  |
|  | ef137ac0b6 |  |
|  | 328d4f16a9 |  |
|  | bdbcb85b8d |  |
|  | 6c9e94bae7 |  |
|  | bfce723311 |  |
|  | 31f5458938 |  |
|  | 2145a202eb |  |
|  | 25818dc848 |  |
|  | 198953cd08 |  |
|  | ec16ee2f39 |  |
|  | d5088072fb |  |
|  | 8d4b50158e |  |
|  | e88c6c03ff |  |
|  | d3cf2b7b24 |  |
|  | 7448f02b7c |  |
|  | 871258aa72 |  |
|  | 66838ebd39 |  |
|  | 7333281698 |  |
|  | 3cd4c5cb0a |  |
|  | 11c6d56037 |  |
|  | 216fea15ee |  |
|  | 58bf8815c8 |  |
|  | 1b38f5bf57 |  |
|  | 2724ac4a60 |  |
|  | f48f90e471 |  |
|  | 6463c39ce0 |  |
|  | 0a7e2ae787 |  |
|  | 03a97b604a |  |
|  | 4446c86052 |  |
|  | 8270ff312f |  |
|  | db2d7ad9ba |  |
|  | 6620d86318 |  |
|  | 111fd0cadf |  |
|  | 776aa734e1 |  |
|  | 5a2ad032cb |  |
|  | d44295ef71 |  |
|  | bf21be066f |  |
|  | 72bbf49349 |  |
@@ -1,2 +0,0 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
pause

@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "130"
default: "129"

python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "6"
# push:
# branches:
# - master

@@ -197,12 +197,10 @@ comfy install

## Manual Install (Windows, Linux)

Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) and install pytorch nightly but it is not recommended.

Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

### Instructions:

Git clone this repo.

Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
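Note on the id scheme above: hashing source + path keeps ids reproducible across backend reloads while staying unique per file. A minimal standalone sketch of that scheme (the example path is hypothetical):

```python
import hashlib

def subgraph_id(source: str, path: str) -> str:
    # Same inputs always produce the same id, so the frontend can keep
    # referencing a subgraph across backend restarts.
    return hashlib.sha256(f"{source}{path}".encode()).hexdigest()

a = subgraph_id("custom_node", "custom_nodes/my_pack/subgraphs/demo.json")
b = subgraph_id("custom_node", "custom_nodes/my_pack/subgraphs/demo.json")
assert a == b and len(a) == 64
```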
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
@@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@@ -64,6 +65,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
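The ControlIsolation helper above temporarily severs previous_controlnet so a call does not cascade down the whole chain. A self-contained sketch of the same context-manager pattern on a dummy object (names here are illustrative, not the real classes):

```python
class DummyControl:
    def __init__(self, previous=None):
        self.previous_controlnet = previous

class Isolation:
    """Temporarily set previous_controlnet to None, then restore it on exit."""
    def __init__(self, control):
        self.control = control
        self.orig_previous = control.previous_controlnet

    def __enter__(self):
        self.control.previous_controlnet = None

    def __exit__(self, *args):
        self.control.previous_controlnet = self.orig_previous

chain = DummyControl(previous=DummyControl())
with Isolation(chain):
    assert chain.previous_controlnet is None      # link severed inside the block
assert chain.previous_controlnet is not None      # link restored afterwards
```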
@@ -77,7 +90,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@@ -85,6 +98,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@@ -111,17 +125,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu function.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@@ -130,7 +165,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@@ -829,6 +872,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
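Both ControlNet.deepclone_multigpu and T2IAdapter.deepclone_multigpu above follow the same recipe: copy the wrapper, deep-copy the underlying model, retarget it to the new device, and optionally register the clone under that device. A generic sketch of that recipe with dummy objects (not the real patcher classes):

```python
import copy

class Wrapper:
    def __init__(self, model, device="cuda:0"):
        self.model = model
        self.device = device
        self.multigpu_clones = {}                  # device -> clone

    def deepclone_multigpu(self, load_device, autoregister=False):
        c = copy.copy(self)                        # cheap copy of the wrapper
        c.model = copy.deepcopy(self.model)        # the weights get their own copy
        c.device = load_device
        c.multigpu_clones = {}
        if autoregister:
            self.multigpu_clones[load_device] = c
        return c

w = Wrapper(model={"weights": [0.0]})
clone = w.deepclone_multigpu("cuda:1", autoregister=True)
clone.model["weights"][0] = 1.0
assert w.model["weights"][0] == 0.0                # original stays untouched
```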
@@ -189,15 +189,15 @@ class ChromaRadiance(Chroma):
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
@@ -240,8 +240,17 @@ class ChromaRadiance(Chroma):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
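The fix above slices tiles along the patch dimension ([:, i:end, :]) rather than a scaled batch index, so the final tile can simply be shorter. A minimal standalone sketch of that tiling loop on dummy tensors (sizes are illustrative):

```python
import torch

B, num_patches, D = 2, 70, 16
hidden = torch.randn(B, num_patches, D)
tile_size = 32

outputs = []
for i in range(0, num_patches, tile_size):
    end = min(i + tile_size, num_patches)
    tile = hidden[:, i:end, :]                     # last tile may be smaller
    n = tile.shape[1]
    flat = tile.reshape(B * n, D)                  # per-patch processing view
    outputs.append(flat.reshape(B, n, D))          # back to [B, n, D]

out = torch.cat(outputs, dim=1)
assert out.shape == hidden.shape                   # every patch covered exactly once
```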
@@ -197,14 +197,8 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra

t = self.process_timestep(t, x=x, **extra_conds)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))

model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)

return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)

def process_timestep(self, timestep, **kwargs):
return timestep

@@ -213,7 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_tile_size"] = 32
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
@@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@@ -27,6 +28,10 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -89,7 +94,6 @@ if args.deterministic:
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
@@ -187,6 +191,25 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
devices.remove(get_torch_device())
|
||||
return devices
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
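get_all_torch_devices above enumerates every visible accelerator per backend (CUDA, XPU, NPU) with the same count-and-index loop, falling back to the single default device otherwise. A hedged standalone sketch for the CUDA case only:

```python
import torch

def all_cuda_devices(exclude_current: bool = False) -> list:
    # Mirrors the count-and-index loop above, limited to CUDA;
    # falls back to CPU when no GPU is available.
    if not torch.cuda.is_available():
        return [torch.device("cpu")]
    devices = [torch.device(i) for i in range(torch.cuda.device_count())]
    if exclude_current:
        devices.remove(torch.device(torch.cuda.current_device()))
    return devices

print(all_cuda_devices())
```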
@@ -331,21 +354,14 @@ except:
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
@@ -379,9 +395,6 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
@@ -444,9 +457,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@@ -457,7 +474,7 @@ def module_size(module):
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@@ -465,7 +482,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@@ -999,6 +1016,12 @@ def device_supports_non_blocking(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
@@ -1332,7 +1355,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
@@ -1401,8 +1424,34 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@@ -87,12 +87,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
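The new copy_tuples flag in create_hook_patches_clone matters because the [:] slice copies the per-hook lists but not the entries inside them; converting each entry to a tuple gives a multigpu clone its own immutable copies. A toy illustration of the difference (not the real hook structures):

```python
orig = {"hook_ref": {"weights": [["w0", 1.0], ["w1", 2.0]]}}

# Slice copy: new outer lists, but the inner entries are still shared.
shallow = {h: {k: v[:] for k, v in d.items()} for h, d in orig.items()}
shallow["hook_ref"]["weights"][0][0] = "mutated"
assert orig["hook_ref"]["weights"][0][0] == "mutated"   # original affected

# copy_tuples-style copy: entries become immutable tuples, so a clone
# can no longer mutate the entries the original dict still points to.
orig["hook_ref"]["weights"][0][0] = "w0"                 # reset for the demo
isolated = {h: {k: [tuple(e) for e in v] for k, v in d.items()} for h, d in orig.items()}
assert isolated["hook_ref"]["weights"][0] == ("w0", 1.0)
```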
@@ -257,6 +260,9 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -335,18 +341,92 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@@ -983,7 +1063,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@@ -1037,9 +1117,13 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options, ignore_multigpu)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@@ -1052,12 +1136,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@@ -1069,6 +1159,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
|
||||
167
comfy/multigpu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
||||
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
# new_multigpu_models = []
|
||||
# for m in multigpu_models:
|
||||
# if m.load_device in limit_extra_devices:
|
||||
# new_multigpu_models.append(m)
|
||||
# model.set_additional_models("multigpu", new_multigpu_models)
|
||||
# persist skip_devices for use in sampling code
|
||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
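load_balance_devices above assigns each device work proportional to its relative speed, w = (W * r) / R, and then uses round_preserved so the integer shares still sum to the total. A quick standalone check of that rounding step (condensed from the round_preserved shown above):

```python
def round_preserved(values):
    'Round all values in a list, preserving the combined sum of values.'
    floored = [int(x) for x in values]
    remainder = round(sum(values)) - sum(floored)
    # hand the leftover units to the largest fractional parts first
    fractional = sorted(enumerate(values), key=lambda p: p[1] - int(p[1]), reverse=True)
    for i in range(remainder):
        floored[fractional[i][0]] += 1
    return floored

# 10 units of work across relative speeds 1.0 : 1.5 : 1.5 -> shares 2.5 : 3.75 : 3.75
total_work, speeds = 10, [1.0, 1.5, 1.5]
shares = [total_work * s / sum(speeds) for s in speeds]
print(round_preserved(shares))   # [2, 4, 4] -> still sums to 10
```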
@@ -1,91 +0,0 @@
|
||||
import torch
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors):
|
||||
self.tensors = list(tensors)
|
||||
self.is_nested = True
|
||||
|
||||
def _copy(self):
|
||||
return NestedTensor(self.tensors)
|
||||
|
||||
def apply_operation(self, other, operation):
|
||||
o = self._copy()
|
||||
if isinstance(other, NestedTensor):
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other.tensors[i])
|
||||
else:
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other)
|
||||
return o
|
||||
|
||||
def __add__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x + y)
|
||||
|
||||
def __sub__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x - y)
|
||||
|
||||
def __mul__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x * y)
|
||||
|
||||
# def __itruediv__(self, b):
|
||||
# return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __truediv__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
||||
|
||||
def unbind(self):
|
||||
return self.tensors
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
o = self._copy()
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = t.to(*args, **kwargs)
|
||||
return o
|
||||
|
||||
def new_ones(self, *args, **kwargs):
|
||||
return self.tensors[0].new_ones(*args, **kwargs)
|
||||
|
||||
def float(self):
|
||||
return self.to(dtype=torch.float)
|
||||
|
||||
def chunk(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
||||
|
||||
def size(self):
|
||||
return self.tensors[0].size()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.tensors[0].shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
dims = 0
|
||||
for t in self.tensors:
|
||||
dims = max(t.ndim, dims)
|
||||
return dims
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensors[0].device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self.tensors[0].layout
|
||||
|
||||
|
||||
def cat_nested(tensors, *args, **kwargs):
|
||||
cated_tensors = []
|
||||
for i in range(len(tensors[0].tensors)):
|
||||
tens = []
|
||||
for j in range(len(tensors)):
|
||||
tens.append(tensors[j].tensors[i])
|
||||
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
||||
return NestedTensor(cated_tensors)
|
||||
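The NestedTensor wrapper above (removed in this compare) just maps elementwise operations across a list of tensors that may have different shapes. A trimmed, self-contained sketch of how it behaves:

```python
import torch

class MiniNestedTensor:
    """Trimmed version of the wrapper shown above, for illustration only."""
    def __init__(self, tensors):
        self.tensors = list(tensors)

    def apply_operation(self, other, operation):
        out = MiniNestedTensor(self.tensors)
        for i, t in enumerate(out.tensors):
            rhs = other.tensors[i] if isinstance(other, MiniNestedTensor) else other
            out.tensors[i] = operation(t, rhs)
        return out

    def __mul__(self, b):
        return self.apply_operation(b, lambda x, y: x * y)

    def __add__(self, b):
        return self.apply_operation(b, lambda x, y: x + y)

    def unbind(self):
        return self.tensors

a = MiniNestedTensor([torch.ones(1, 4, 8, 8), torch.ones(1, 4, 16, 16)])
b = a * 2 + 1                                  # applied to every wrapped tensor
print([t.shape for t in b.unbind()])           # per-entry shapes are preserved
```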
26
comfy/ops.py
@@ -25,9 +25,6 @@ import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
return
|
||||
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
@@ -55,22 +52,14 @@ try:
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed'
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
except:
|
||||
pass
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
@@ -162,15 +151,6 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||
return out
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Callable

class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
@@ -4,9 +4,13 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
import comfy.nested_tensor
|
||||
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
@@ -17,29 +21,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
noises = []
|
||||
for t in tensors:
|
||||
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||
else:
|
||||
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
|
||||
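The refactor above seeds one CPU generator and reuses it for every entry of a nested latent, so noise is deterministic per seed but distinct per entry. A small sketch of the seeding behaviour it relies on:

```python
import torch

def make_noise(shape, seed):
    # Seeding the default generator makes the draw reproducible per seed.
    generator = torch.manual_seed(seed)
    return torch.randn(shape, generator=generator, device="cpu")

a = make_noise((1, 4, 8, 8), seed=42)
b = make_noise((1, 4, 8, 8), seed=42)
assert torch.equal(a, b)                       # reproducible for a fixed seed
```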
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@@ -106,6 +107,47 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@@ -130,7 +172,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@@ -182,3 +225,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@@ -1,7 +1,9 @@
from __future__ import annotations

import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -20,6 +22,7 @@ import comfy.context_windows
import comfy.utils
import scipy.stats
import numpy
import threading

def add_area_dims(area, num_dims):
@@ -142,7 +145,7 @@ def can_concat_cond(c1, c2):

return cond_equal_size(c1.conditioning, c2.conditioning)

def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -154,6 +157,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)

return out

@@ -213,7 +218,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)

def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -245,7 +252,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)

# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -346,6 +353,196 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens

return out_conds

def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False

output_device = x_in.device

for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)

cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)

if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

model.current_patcher.prepare_state(timestep, model_options)

devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}

total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])

first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]

to_batch_temp.reverse()
to_batch = to_batch_temp[:1]

free_memory = model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)

batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None

def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches

batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)

transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)

if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)

transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device

cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options

if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise

results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()

for thread in threads:
thread.join()

for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]

for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]

return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -650,6 +847,8 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)

def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -782,7 +981,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)

def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -792,7 +991,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N

if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -892,7 +1091,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)

def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -901,18 +1102,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return

# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -922,8 +1122,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -931,7 +1131,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -962,11 +1161,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)

self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -980,10 +1179,12 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))

def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device

multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)

if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

@@ -994,9 +1195,13 @@ class CFGGuider:

try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()

comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
@@ -1007,12 +1212,6 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image

if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]

self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1032,7 +1231,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1040,9 +1239,6 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()

del self.conds

if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output
@@ -1106,25 +1106,3 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out

def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))

latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes

def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors
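A minimal usage sketch (not part of the diff above; it assumes the pack_latents/unpack_latents helpers from this hunk and torch are in scope): latents with matching batch size but different spatial shapes are flattened and concatenated along the last dimension, then sliced back out using the recorded shapes.

import torch

# Hypothetical inputs: same batch size, different spatial dimensions.
a = torch.randn(1, 4, 32, 32)
b = torch.randn(1, 4, 16, 24)

packed, shapes = pack_latents([a, b])      # packed: [1, 1, 4*32*32 + 4*16*24]
restored = unpack_latents(packed, shapes)  # tensors restored to their original shapes

assert restored[0].shape == a.shape
assert restored[1].shape == b.shape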
@@ -1,8 +1,15 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
import os
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
@@ -19,8 +26,43 @@ from PIL import Image
import torch
import math
import base64
from .util import tensor_to_bytesio, bytesio_to_image_tensor
import uuid
from io import BytesIO
import av

async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.

Args:
video_url: The URL of the video to download.

Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)

def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)

s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s

async def validate_and_cast_response(
@@ -120,6 +162,11 @@ def validate_aspect_ratio(
return aspect_ratio

def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()

async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
@@ -148,11 +195,136 @@ async def download_url_to_bytesio(
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.

Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").

Returns:
A torch.Tensor representing the image (1, H, W, C).

Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)

async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)

def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))

def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img

def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"

img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr

def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.

Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').

Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"

pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.

Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').

Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string

def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.

Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"

def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
@@ -207,6 +379,238 @@ async def upload_file_to_comfyapi(
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.

Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()

# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)

video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")

async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.

Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.

Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e

upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"

# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)

return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.

Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.

Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")

# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]

# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)

return audio_data_np

def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0

for packet in audio_stream.encode(frame):
output_container.mux(packet)

# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)

audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.

Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).

Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)

return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)

def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")

def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]

in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr

frames: list[torch.Tensor] = []
n_channels = stream.channels or 1

for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)

if not frames:
raise ValueError("Decoded zero audio frames.")

wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
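A hedged illustration only (not from the diff; it assumes audio_bytes_to_audio_input as defined above is in scope, and "some_audio.mp3" is a placeholder path): decoding raw container bytes yields the Comfy AUDIO dict described in the docstring.

# Placeholder file name; any container PyAV can read should work.
with open("some_audio.mp3", "rb") as f:
    audio = audio_bytes_to_audio_input(f.read())

print(audio["waveform"].shape)   # e.g. torch.Size([1, 2, T]) for stereo
print(audio["sample_rate"])      # e.g. 44100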
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()

output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")

out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000

frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer

def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")

async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
@@ -259,3 +663,56 @@ def resize_mask_to_image(
if not allow_gradient:
mask = (mask > 0.5).float()
return mask

def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)

def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
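A brief usage sketch (illustrative, not part of the diff; it assumes image_tensor_pair_to_batch as defined above and torch are in scope): the second [B, H, W, C] image is resized to the first image's resolution before the two are concatenated along the batch dimension.

import torch

img1 = torch.rand(1, 512, 512, 3)
img2 = torch.rand(1, 256, 384, 3)

batch = image_tensor_pair_to_batch(img1, img2)
print(batch.shape)  # torch.Size([2, 512, 512, 3])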
def get_size(path_or_object: Union[str, io.BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())

def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")

@@ -50,6 +50,44 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')

class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')

class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')

class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@@ -122,8 +160,15 @@ class BFLStatus(str, Enum):
error = "Error"

class BFLFluxStatusResponse(BaseModel):
class BFLFluxProStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
@@ -1,20 +1,13 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union

from pydantic import BaseModel, Field, RootModel

class TripoModelVersion(str, Enum):
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'

class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'

class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"

@@ -1,111 +0,0 @@
from typing import Optional, Union
from enum import Enum

from pydantic import BaseModel, Field

class Image2(BaseModel):
bytesBase64Encoded: str
gcsUri: Optional[str] = None
mimeType: Optional[str] = None

class Image3(BaseModel):
bytesBase64Encoded: Optional[str] = None
gcsUri: str
mimeType: Optional[str] = None

class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
prompt: str = Field(..., description='Text description of the video')

class PersonGeneration1(str, Enum):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'

class Parameters1(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
generateAudio: Optional[bool] = Field(
None,
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)

class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None
parameters: Optional[Parameters1] = None

class VeoGenVidResponse(BaseModel):
name: str = Field(
...,
description='Operation resource name',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
],
)

class VeoGenVidPollRequest(BaseModel):
operationName: str = Field(
...,
description='Full operation name (from predict response)',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
],
)

class Video(BaseModel):
bytesBase64Encoded: Optional[str] = Field(
None, description='Base64-encoded video content'
)
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
mimeType: Optional[str] = Field(None, description='Video MIME type')

class Error1(BaseModel):
code: Optional[int] = Field(None, description='Error code')
message: Optional[str] = Field(None, description='Error message')

class Response1(BaseModel):
field_type: Optional[str] = Field(
None,
alias='@type',
examples=[
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
],
)
raiMediaFilteredCount: Optional[int] = Field(
None, description='Count of media filtered by responsible AI policies'
)
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None

class VeoGenVidPollResponse(BaseModel):
done: Optional[bool] = None
error: Optional[Error1] = Field(
None, description='Error details if operation failed'
)
name: Optional[str] = None
response: Optional[Response1] = Field(
None, description='The actual prediction response if done is true'
)
@@ -1,43 +1,136 @@
import asyncio
import io
from inspect import cleandoc
from typing import Optional

import torch
from typing import Union, Optional
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_aspect_ratio,
)
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxKontextProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLStatus,
BFLFluxProGenerateResponse,
)
from comfy_api_nodes.util (
from comfy_api_nodes.apis.client (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
sync_op,
tensor_to_base64_string,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_aspect_ratio,
process_image_response,
resize_mask_to_image,
validate_string,
)

import numpy as np
from PIL import Image
import aiohttp
import torch
import base64
import time
from server import PromptServer

def convert_mask_to_image(mask: torch.Tensor):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
mask = torch.cat([mask]*3, dim=-1)
return mask

async def handle_bfl_synchronous_operation(
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = await operation.execute()
return await _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)

async def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
retries_404 = 0
max_retries_404 = 5
retry_404_seconds = 2
retry_202_seconds = 2
retry_pending_seconds = 1

async with aiohttp.ClientSession() as session:
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)

async with session.get(polling_url) as response:
if response.status == 200:
result = await response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
async with session.get(img_url) as img_resp:
return process_image_response(await img_resp.content.read())
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
await asyncio.sleep(retry_pending_seconds)
continue
elif response.status == 404:
if retries_404 < max_retries_404:
retries_404 += 1
await asyncio.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status == 202:
await asyncio.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")

def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
# remove batch dimension if present
if len(scaled_image.shape) > 3:
scaled_image = scaled_image[0]
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
return base64.b64encode(img_byte_arr.getvalue()).decode()
class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
@@ -65,9 +158,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"seed",
@@ -129,19 +220,22 @@ class FluxProUltraImageNode(IO.ComfyNode):
cls,
prompt: str,
aspect_ratio: str,
prompt_upsampling: bool = False,
raw: bool = False,
seed: int = 0,
image_prompt: Optional[torch.Tensor] = None,
image_prompt_strength: float = 0.1,
prompt_upsampling=False,
raw=False,
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
) -> IO.NodeOutput:
if image_prompt is None:
validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxProUltraGenerateRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
method=HttpMethod.POST,
request_model=BFLFluxProUltraGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxProUltraGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
seed=seed,
@@ -153,26 +247,22 @@ class FluxProUltraImageNode(IO.ComfyNode):
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
raw=raw,
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
image_prompt=(
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
),
image_prompt_strength=(
None if image_prompt is None else round(image_prompt_strength, 2)
),
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)

class FluxKontextProImageNode(IO.ComfyNode):
@@ -257,7 +347,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor] = None,
input_image: Optional[torch.Tensor]=None,
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
@@ -270,36 +360,33 @@ class FluxKontextProImageNode(IO.ComfyNode):
)
if input_image is None:
validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op(
cls,
ApiEndpoint(path=cls.BFL_PATH, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxKontextProGenerateRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=cls.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=aspect_ratio,
input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)),
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
@@ -335,9 +422,7 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
@@ -396,15 +481,20 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
) -> IO.NodeOutput:
|
||||
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
image_prompt = (
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
)
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||
method="POST",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxProGenerateRequest(
|
||||
request=BFLFluxProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
width=width,
|
||||
@@ -412,23 +502,13 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
class FluxProExpandNode(IO.ComfyNode):
|
||||
@@ -454,9 +534,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"top",
|
||||
@@ -532,11 +610,16 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxExpandImageRequest(
|
||||
image = convert_image_to_base64(image)
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-expand/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxExpandImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxExpandImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
top=top,
|
||||
@@ -546,25 +629,16 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=tensor_to_base64_string(image),
|
||||
image=image,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
class FluxProFillNode(IO.ComfyNode):
|
||||
@@ -591,9 +665,7 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
@@ -640,37 +712,272 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = tensor_to_base64_string(convert_mask_to_image(mask))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxFillImageRequest(
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:, :, :, :3])
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-fill/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxFillImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxFillImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
image=image,
|
||||
mask=mask,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
class FluxProCannyNode(IO.ComfyNode):
|
||||
"""
|
||||
Generate image using a control image (canny).
|
||||
"""
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxProCannyNode",
|
||||
display_name="Flux.1 Canny Control Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("control_image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"canny_low_threshold",
|
||||
default=0.1,
|
||||
min=0.01,
|
||||
max=0.99,
|
||||
step=0.01,
|
||||
tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"canny_high_threshold",
|
||||
default=0.4,
|
||||
min=0.01,
|
||||
max=0.99,
|
||||
step=0.01,
|
||||
tooltip="High threshold for Canny edge detection; ignored if skip_processing is True",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"skip_preprocessing",
|
||||
default=False,
|
||||
tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
default=30,
|
||||
min=1,
|
||||
max=100,
|
||||
tooltip="Guidance strength for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=50,
|
||||
min=15,
|
||||
max=50,
|
||||
tooltip="Number of steps for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
queued_statuses=[],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
canny_low_threshold: float,
|
||||
canny_high_threshold: float,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
control_image = convert_image_to_base64(control_image[:, :, :, :3])
|
||||
preprocessed_image = None
# scale canny threshold between 0-500, to match BFL's API
|
||||
def scale_value(value: float, min_val=0, max_val=500):
|
||||
return min_val + value * (max_val - min_val)
|
||||
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
|
||||
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
canny_low_threshold = None
|
||||
canny_high_threshold = None
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-canny/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxCannyImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxCannyImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
canny_low_threshold=canny_low_threshold,
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
class FluxProDepthNode(IO.ComfyNode):
|
||||
"""
|
||||
Generate image using a control image (depth).
|
||||
"""
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxProDepthNode",
|
||||
display_name="Flux.1 Depth Control Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("control_image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"skip_preprocessing",
|
||||
default=False,
|
||||
tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
default=15,
|
||||
min=1,
|
||||
max=100,
|
||||
tooltip="Guidance strength for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=50,
|
||||
min=15,
|
||||
max=50,
|
||||
tooltip="Number of steps for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
preprocessed_image = None
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-depth/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxDepthImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxDepthImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
class BFLExtension(ComfyExtension):
|
||||
@@ -683,6 +990,8 @@ class BFLExtension(ComfyExtension):
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
FluxProCannyNode,
|
||||
FluxProDepthNode,
|
||||
]
@@ -1,27 +1,35 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Type, Union
|
||||
from typing_extensions import override
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
validate_image_dimensions,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
EmptyRequest,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
T,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_image_aspect_ratio_range,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
image_tensor_pair_to_batch,
|
||||
)
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
# Long-running tasks endpoints(e.g., video)
|
||||
@@ -38,14 +46,13 @@ class Image2ImageModelName(str, Enum):
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
@@ -201,6 +208,35 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
return None
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
task_id: str,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
completed_statuses=[
|
||||
"succeeded",
|
||||
],
|
||||
failed_statuses=[
|
||||
"cancelled",
|
||||
"failed",
|
||||
],
|
||||
status_extractor=lambda response: response.status,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
).execute()
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
|
||||
@@ -267,7 +303,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -305,7 +341,8 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
w, h = width, height
|
||||
if not (512 <= w <= 2048) or not (512 <= h <= 2048):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels."
|
||||
f"Custom size out of range: {w}x{h}. "
|
||||
"Both width and height must be between 512 and 2048 pixels."
|
||||
)
|
|
||||
@@ -316,12 +353,20 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@@ -375,7 +420,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -404,7 +449,16 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
|
||||
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
source_url = (await upload_images_to_comfyapi(
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
))[0]
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
@@ -413,12 +467,16 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@@ -446,7 +504,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image(s) for image-to-image generation. "
|
||||
"List of 1-10 images for single or multi-reference generation.",
|
||||
"List of 1-10 images for single or multi-reference generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@@ -476,9 +534,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
"sequential_image_generation",
|
||||
options=["disabled", "auto"],
|
||||
tooltip="Group image generation mode. "
|
||||
"'disabled' generates a single image. "
|
||||
"'auto' lets the model decide whether to generate multiple related images "
|
||||
"(e.g., story scenes, character variations).",
|
||||
"'disabled' generates a single image. "
|
||||
"'auto' lets the model decide whether to generate multiple related images "
|
||||
"(e.g., story scenes, character variations).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@@ -489,7 +547,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Maximum number of images to generate when sequential_image_generation='auto'. "
|
||||
"Total images (input + generated) cannot exceed 15.",
|
||||
"Total images (input + generated) cannot exceed 15.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@@ -506,7 +564,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
@@ -553,7 +611,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
w, h = width, height
|
||||
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
f"Custom size out of range: {w}x{h}. "
|
||||
"Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
@@ -562,31 +621,41 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
reference_images_urls = []
|
||||
if n_input_images:
|
||||
for i in image:
|
||||
validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
reference_images_urls = (await upload_images_to_comfyapi(
|
||||
image,
|
||||
max_images=n_input_images,
|
||||
mime_type="image/png",
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
response_model=ImageTaskCreationResponse,
|
||||
data=Seedream4TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image=reference_images_urls,
|
||||
size=f"{w}x{h}",
|
||||
seed=seed,
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
))
|
||||
payload = Seedream4TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image=reference_images_urls,
|
||||
size=f"{w}x{h}",
|
||||
seed=seed,
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Seedream4TaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
|
||||
@@ -650,13 +719,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -695,9 +764,19 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
f"--camerafixed {str(camera_fixed).lower()} "
|
||||
f"--watermark {str(watermark).lower()}"
|
||||
)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
payload=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt)],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
@@ -761,13 +840,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -800,7 +879,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0]
|
||||
|
||||
prompt = (
|
||||
f"{prompt} "
|
||||
f"--resolution {resolution} "
|
||||
@@ -812,11 +897,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
return await process_video_task(
|
||||
cls,
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
@@ -884,13 +971,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -925,11 +1012,16 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image_tensor_pair_to_batch(first_frame, last_frame),
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
@@ -943,7 +1035,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
return await process_video_task(
|
||||
cls,
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[
|
||||
@@ -952,6 +1044,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"),
|
||||
],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
@@ -1014,7 +1108,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -1047,7 +1141,15 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
image_urls = await upload_images_to_comfyapi(
|
||||
images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"{prompt} "
|
||||
f"--resolution {resolution} "
|
||||
@@ -1058,32 +1160,42 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
)
|
||||
x = [
|
||||
TaskTextContent(text=prompt),
|
||||
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls],
|
||||
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls]
|
||||
]
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=x,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
request_model: Type[T],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
auth_kwargs: dict,
|
||||
node_id: str,
|
||||
estimated_duration: Optional[int],
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_TASK_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
initial_response.id,
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
node_id=node_id,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
@@ -1109,6 +1221,5 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDanceImageReferenceNode,
|
||||
]
async def comfy_entrypoint() -> ByteDanceExtension:
|
||||
return ByteDanceExtension()
@@ -2,47 +2,45 @@
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
from __future__ import annotations
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Literal, Optional
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from server import PromptServer
|
||||
from comfy_api_nodes.apis import (
|
||||
GeminiContent,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiInlineData,
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
GeminiImageGenerationConfig,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from server import PromptServer
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
validate_string,
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
)
|
||||
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
@@ -68,6 +66,50 @@ class GeminiImageModel(str, Enum):
|
||||
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
def get_gemini_endpoint(
|
||||
model: GeminiModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
def get_gemini_image_endpoint(
|
||||
model: GeminiImageModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiImageModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiImageGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
@@ -80,7 +122,9 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@@ -92,7 +136,37 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
return image_parts
|
||||
|
||||
|
||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||
def create_text_part(text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
|
||||
def get_parts_from_response(
|
||||
response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
def get_parts_by_type(
|
||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
@@ -104,10 +178,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
for part in get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
@@ -135,11 +213,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
return torch.zeros((1,1024,1024,4))
|
||||
return torch.cat(image_tensors, dim=0)
class GeminiNode(IO.ComfyNode):
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
|
||||
@@ -150,79 +228,96 @@ class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNode",
|
||||
display_name="Google Gemini",
|
||||
category="api node/text/Gemini",
|
||||
description="Generate text responses with Google's Gemini AI model. "
|
||||
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||
"as context for generating more relevant and meaningful responses.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text inputs to the model, used to generate a response. "
|
||||
"You can include detailed instructions, questions, or context for the model.",
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
|
||||
},
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=GeminiModel,
|
||||
default=GeminiModel.gemini_2_5_pro,
|
||||
tooltip="The Gemini model to use for generating responses.",
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiModel],
|
||||
"default": GeminiModel.gemini_2_5_pro.value,
|
||||
},
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional image(s) to use as context for the model. "
|
||||
"To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Optional audio to use as context for the model.",
|
||||
"audio": (
|
||||
IO.AUDIO,
|
||||
{
|
||||
"tooltip": "Optional audio to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
optional=True,
|
||||
tooltip="Optional video to use as context for the model.",
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"tooltip": "Optional video to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
video_input: Video tensor from ComfyUI.
|
||||
**kwargs: Additional arguments to pass to the conversion function.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded video.
|
||||
"""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input,
|
||||
container_format=VideoContainer.MP4,
|
||||
codec=VideoCodec.H264
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@@ -232,8 +327,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
@@ -246,10 +340,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
audio_at_index = {
|
||||
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
"sample_rate": audio_input["sample_rate"],
|
||||
}
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
@@ -266,38 +360,38 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
video: Optional[Input.Video] = None,
|
||||
model: GeminiModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
audio: Optional[IO.AUDIO] = None,
|
||||
video: Optional[IO.VIDEO] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
) -> IO.NodeOutput:
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str]:
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(cls.create_video_parts(video))
|
||||
parts.extend(self.create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_endpoint(model),
|
||||
request=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
@@ -305,15 +399,15 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
]
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
# Get result output
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
@@ -333,10 +427,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
|
||||
|
||||
class GeminiInputFiles(IO.ComfyNode):
|
||||
class GeminiInputFiles(ComfyNodeABC):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
|
||||
@@ -347,7 +441,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
"""
|
||||
For details about the supported file input types, see:
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
@@ -362,37 +456,39 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
]
|
||||
input_files = sorted(input_files, key=lambda x: x.name)
|
||||
input_files = [f.name for f in input_files]
|
||||
return IO.Schema(
|
||||
node_id="GeminiInputFiles",
|
||||
display_name="Gemini Input Files",
|
||||
category="api node/text/Gemini",
|
||||
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||
"The files will be read by the Gemini model when generating a response. "
|
||||
"The contents of the text file count toward the token limit. "
|
||||
"🛈 TIP: Can be chained together with other Gemini Input File nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"file",
|
||||
options=input_files,
|
||||
default=input_files[0] if input_files else None,
|
||||
tooltip="Input files to include as context for the model. "
|
||||
"Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
return {
|
||||
"required": {
|
||||
"file": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
"options": input_files,
|
||||
"default": input_files[0] if input_files else None,
|
||||
},
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
},
|
||||
"optional": {
|
||||
"GEMINI_INPUT_FILES": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
optional=True,
|
||||
tooltip="An optional additional file(s) to batch together with the file loaded from this node. "
|
||||
"Allows chaining of input files so that a single message can include multiple input files.",
|
||||
{
|
||||
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||
"default": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Custom("GEMINI_INPUT_FILES").Output(),
|
||||
],
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_file_part(cls, file_path: str) -> GeminiPart:
|
||||
mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain
|
||||
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
|
||||
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
|
||||
FUNCTION = "prepare_files"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
|
||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||
mime_type = (
|
||||
GeminiMimeType.application_pdf
|
||||
if file_path.endswith(".pdf")
|
||||
else GeminiMimeType.text_plain
|
||||
)
|
||||
# Use base64 string directly, not the data URI
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
@@ -405,95 +501,120 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
|
||||
"""Loads and formats input files for Gemini API."""
|
||||
if GEMINI_INPUT_FILES is None:
|
||||
GEMINI_INPUT_FILES = []
|
||||
def prepare_files(
|
||||
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
|
||||
) -> tuple[list[GeminiPart]]:
|
||||
"""
|
||||
Loads and formats input files for Gemini API.
|
||||
"""
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
input_file_content = cls.create_file_part(file_path)
|
||||
return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES)
|
||||
input_file_content = self.create_file_part(file_path)
|
||||
files = [input_file_content] + GEMINI_INPUT_FILES
|
||||
return (files,)
|
class GeminiImage(IO.ComfyNode):
|
||||
class GeminiImage(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text and image responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, files) to generate coherent
|
||||
text and image responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Google Gemini Image",
|
||||
category="api node/image/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for generation",
|
||||
default="",
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for generation",
|
||||
},
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=GeminiImageModel,
|
||||
default=GeminiImageModel.gemini_2_5_flash_image,
|
||||
tooltip="The Gemini model to use for generating responses.",
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiImageModel],
|
||||
"default": GeminiImageModel.gemini_2_5_flash_image.value,
|
||||
},
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional image(s) to use as context for the model. "
|
||||
"To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
default="auto",
|
||||
tooltip="Defaults to matching the output image size to that of your input image, "
|
||||
"or otherwise generates 1:1 squares.",
|
||||
optional=True,
|
||||
# TODO: we can add this parameter later
|
||||
# "n": (
|
||||
# IO.INT,
|
||||
# {
|
||||
# "default": 1,
|
||||
# "min": 1,
|
||||
# "max": 8,
|
||||
# "step": 1,
|
||||
# "display": "number",
|
||||
# "tooltip": "How many images to generate",
|
||||
# },
|
||||
# ),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
|
||||
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
"default": "auto",
|
||||
},
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Gemini"
|
||||
DESCRIPTION = "Edit images synchronously via Google API."
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
model: GeminiImageModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
n=1,
|
||||
aspect_ratio: str = "auto",
|
||||
) -> IO.NodeOutput:
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
if not aspect_ratio:
|
||||
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
|
||||
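A short illustrative sketch of what this fallback feeds into further down: "auto" suppresses the image config block entirely, any other value forwards it (image_config is built from the chosen ratio elsewhere in this method):

image_config_to_send = None if aspect_ratio == "auto" else image_config
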
@@ -505,27 +626,29 @@ class GeminiImage(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_image_endpoint(model),
|
||||
request=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role="user", parts=parts),
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT", "IMAGE"],
|
||||
responseModalities=["TEXT","IMAGE"],
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
)
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
@@ -546,18 +669,17 @@ class GeminiImage(IO.ComfyNode):
|
||||
)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return IO.NodeOutput(output_image, output_text)
|
||||
return (output_image, output_text,)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiImage,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiImageNode": GeminiImage,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> GeminiExtension:
|
||||
return GeminiExtension()
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiImageNode": "Google Gemini Image",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar
|
||||
from typing import Optional, TypeVar, Any
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
import logging
|
||||
|
||||
@@ -14,6 +15,7 @@ from typing_extensions import override
|
||||
import torch
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
KlingTaskStatus,
|
||||
KlingCameraControl,
|
||||
KlingCameraConfig,
|
||||
KlingCameraControlType,
|
||||
@@ -50,20 +52,26 @@ from comfy_api_nodes.apis import (
|
||||
KlingCharacterEffectModelName,
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_base64_string,
|
||||
download_url_to_video_output,
|
||||
upload_video_to_comfyapi,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
sync_op,
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
@@ -206,6 +214,34 @@ VOICES_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
async def poll_until_finished(
    auth_kwargs: dict[str, str],
    api_endpoint: ApiEndpoint[Any, R],
    result_url_extractor: Optional[Callable[[R], str]] = None,
    estimated_duration: Optional[int] = None,
    node_id: Optional[str] = None,
) -> R:
    """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
    return await PollingOperation(
        poll_endpoint=api_endpoint,
        completed_statuses=[
            KlingTaskStatus.succeed.value,
        ],
        failed_statuses=[KlingTaskStatus.failed.value],
        status_extractor=lambda response: (
            response.data.task_status.value
            if response.data and response.data.task_status
            else None
        ),
        auth_kwargs=auth_kwargs,
        result_url_extractor=result_url_extractor,
        estimated_duration=estimated_duration,
        node_id=node_id,
        poll_interval=16.0,
        max_poll_attempts=256,
    ).execute()

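A hedged usage sketch of the helper above, mirroring how the text-to-video path later in this file calls it (task_id and auth_kwargs come from the task-creation step; constants are the ones defined in this module):

final_response = await poll_until_finished(
    auth_kwargs,
    ApiEndpoint(
        path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
        method=HttpMethod.GET,
        request_model=EmptyRequest,
        response_model=KlingText2VideoResponse,
    ),
    result_url_extractor=get_video_url_from_response,
    estimated_duration=AVERAGE_DURATION_T2V,
    node_id=node_id,
)
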
def is_valid_camera_control_configs(configs: list[float]) -> bool:
|
||||
"""Verifies that at least one camera control configuration is non-zero."""
|
||||
return any(not math.isclose(value, 0.0) for value in configs)
|
||||
@@ -341,7 +377,8 @@ async def image_result_to_node_output(
|
||||
|
||||
|
||||
async def execute_text2video(
|
||||
cls: type[IO.ComfyNode],
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
cfg_scale: float,
|
||||
@@ -352,11 +389,14 @@ async def execute_text2video(
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=KlingText2VideoResponse,
|
||||
data=KlingText2VideoRequest(
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingText2VideoRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
request=KlingText2VideoRequest(
|
||||
prompt=prompt if prompt else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
@@ -366,17 +406,24 @@ async def execute_text2video(
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingText2VideoResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=node_id,
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@@ -385,7 +432,8 @@ async def execute_text2video(
|
||||
|
||||
|
||||
async def execute_image2video(
|
||||
cls: type[IO.ComfyNode],
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -407,11 +455,14 @@ async def execute_image2video(
|
||||
if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
|
||||
model_mode = "pro" # October 5: currently "std" mode is not supported for this model
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
data=KlingImage2VideoRequest(
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingImage2VideoRequest,
|
||||
response_model=KlingImage2VideoResponse,
|
||||
),
|
||||
request=KlingImage2VideoRequest(
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
image=tensor_to_base64_string(start_frame),
|
||||
image_tail=(
|
||||
@@ -426,17 +477,24 @@ async def execute_image2video(
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=KlingImage2VideoRequest,
|
||||
response_model=KlingImage2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=node_id,
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@@ -445,7 +503,8 @@ async def execute_image2video(
|
||||
|
||||
|
||||
async def execute_video_effect(
|
||||
cls: type[IO.ComfyNode],
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
dual_character: bool,
|
||||
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
||||
model_name: str,
|
||||
@@ -471,25 +530,35 @@ async def execute_video_effect(
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"),
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
data=KlingVideoEffectsRequest(
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIDEO_EFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVideoEffectsRequest,
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
),
|
||||
request=KlingVideoEffectsRequest(
|
||||
effect_scene=effect_scene,
|
||||
input=request_input_field,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"),
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=node_id,
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@@ -498,7 +567,8 @@ async def execute_video_effect(
|
||||
|
||||
|
||||
async def execute_lipsync(
|
||||
cls: type[IO.ComfyNode],
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
voice_language: Optional[str] = None,
|
||||
@@ -513,21 +583,24 @@ async def execute_lipsync(
|
||||
validate_video_duration(video, 2, 10)
|
||||
|
||||
# Upload video to Comfy API and get download URL
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs)
|
||||
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(cls, audio)
|
||||
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(PATH_LIP_SYNC, "POST"),
|
||||
response_model=KlingLipSyncResponse,
|
||||
data=KlingLipSyncRequest(
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_LIP_SYNC,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingLipSyncRequest,
|
||||
response_model=KlingLipSyncResponse,
|
||||
),
|
||||
request=KlingLipSyncRequest(
|
||||
input=KlingLipSyncInputObject(
|
||||
video_url=video_url,
|
||||
mode=model_mode,
|
||||
@@ -539,17 +612,24 @@ async def execute_lipsync(
|
||||
voice_id=voice_id,
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"),
|
||||
response_model=KlingLipSyncResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_LIP_SYNC}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingLipSyncResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=node_id,
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@@ -727,7 +807,11 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
model_mode, duration, model_name = MODE_TEXT2VIDEO[mode]
|
||||
return await execute_text2video(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=cfg_scale,
|
||||
@@ -788,7 +872,11 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_text2video(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
cfg_scale=cfg_scale,
|
||||
model_mode=KlingVideoGenMode.std,
|
||||
@@ -856,7 +944,11 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
start_frame=start_frame,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
@@ -925,7 +1017,11 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
camera_control: KlingCameraControl,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||
start_frame=start_frame,
|
||||
cfg_scale=cfg_scale,
|
||||
@@ -1001,7 +1097,11 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
mode, duration, model_name = MODE_START_END_FRAME[mode]
|
||||
return await execute_image2video(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model_name=model_name,
|
||||
@@ -1062,27 +1162,41 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
video_id: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"),
|
||||
response_model=KlingVideoExtendResponse,
|
||||
data=KlingVideoExtendRequest(
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIDEO_EXTEND,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVideoExtendRequest,
|
||||
response_model=KlingVideoExtendResponse,
|
||||
),
|
||||
request=KlingVideoExtendRequest(
|
||||
prompt=prompt if prompt else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
cfg_scale=cfg_scale,
|
||||
video_id=video_id,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"),
|
||||
response_model=KlingVideoExtendResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoExtendResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@@ -1145,7 +1259,11 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
duration: KlingVideoGenDuration,
|
||||
) -> IO.NodeOutput:
|
||||
video, _, duration = await execute_video_effect(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
dual_character=True,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1206,7 +1324,11 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(
|
||||
*(
|
||||
await execute_video_effect(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
dual_character=False,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1257,7 +1379,11 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
voice_language: str,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_lipsync(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
video=video,
|
||||
audio=audio,
|
||||
voice_language=voice_language,
|
||||
@@ -1319,7 +1445,11 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
voice_id, voice_language = VOICES_CONFIG[voice]
|
||||
return await execute_lipsync(
|
||||
cls,
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
video=video,
|
||||
text=text,
|
||||
voice_language=voice_language,
|
||||
@@ -1366,26 +1496,40 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
cloth_image: torch.Tensor,
|
||||
model_name: KlingVirtualTryOnModelName,
|
||||
) -> IO.NodeOutput:
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"),
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
data=KlingVirtualTryOnRequest(
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIRTUAL_TRY_ON,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVirtualTryOnRequest,
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
),
|
||||
request=KlingVirtualTryOnRequest(
|
||||
human_image=tensor_to_base64_string(human_image),
|
||||
cloth_image=tensor_to_base64_string(cloth_image),
|
||||
model_name=model_name,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"),
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
),
|
||||
result_url_extractor=get_images_urls_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
@@ -1481,11 +1625,18 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
else:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
data=KlingImageGenerationsRequest(
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_GENERATIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingImageGenerationsRequest,
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
),
|
||||
request=KlingImageGenerationsRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
@@ -1496,17 +1647,24 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
n=n,
|
||||
aspect_ratio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"),
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
),
|
||||
result_url_extractor=get_images_urls_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
from io import BytesIO
from typing import Optional

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
    ApiEndpoint,
    get_number_of_images,
    sync_op_raw,
    upload_images_to_comfyapi,
    validate_string,
)

MODELS_MAP = {
    "LTX-2 (Pro)": "ltx-2-pro",
    "LTX-2 (Fast)": "ltx-2-fast",
}


class ExecuteTaskRequest(BaseModel):
    prompt: str = Field(...)
    model: str = Field(...)
    duration: int = Field(...)
    resolution: str = Field(...)
    fps: Optional[int] = Field(25)
    generate_audio: Optional[bool] = Field(True)
    image_uri: Optional[str] = Field(None)

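For orientation, a hedged sketch of how this request model is populated by the nodes below (values are illustrative):

req = ExecuteTaskRequest(
    prompt="a slow pan across a foggy harbor at dawn",  # example prompt
    model=MODELS_MAP["LTX-2 (Pro)"],                    # -> "ltx-2-pro"
    duration=8,
    resolution="1920x1080",
    fps=25,
    generate_audio=False,
)
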
class TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TextToVideoNode,
|
||||
ImageToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LtxvApiExtension:
|
||||
return LtxvApiExtension()
|
||||
@@ -35,9 +35,9 @@ from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
from comfy_api_nodes.util import validate_string
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
|
||||
@@ -24,8 +24,8 @@ from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.util import validate_string
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
trim_video,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, InputImpl, IO
|
||||
import av
|
||||
import io
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
|
||||
@@ -47,6 +51,13 @@ MAX_VID_HEIGHT = 10000
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
@@ -58,7 +69,64 @@ def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
|
||||
def get_video_from_response(response):
    video = response.output_url
    logging.info(
        "Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
    )
    return video


def get_video_url_from_response(response) -> Optional[str]:
    """Returns the first video url from the Moonvalley video generation task result.
    Will not raise an error if the response is not valid.
    """
    if response:
        return str(get_video_from_response(response))
    else:
        return None


async def poll_until_finished(
    auth_kwargs: dict[str, str],
    api_endpoint: ApiEndpoint[Any, R],
    result_url_extractor: Optional[Callable[[R], str]] = None,
    node_id: Optional[str] = None,
) -> R:
    """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
    return await PollingOperation(
        poll_endpoint=api_endpoint,
        completed_statuses=[
            "completed",
        ],
        max_poll_attempts=240,  # 64 minutes with 16s interval
        poll_interval=16.0,
        failed_statuses=["error"],
        status_extractor=lambda response: (
            response.status if response and response.status else None
        ),
        auth_kwargs=auth_kwargs,
        result_url_extractor=result_url_extractor,
        node_id=node_id,
    ).execute()


def validate_prompts(
    prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
    """Verifies that the prompt isn't empty and that neither prompt is too long."""
    if not prompt:
        raise ValueError("Positive prompt is empty")
    if len(prompt) > max_length:
        raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
    if negative_prompt and len(negative_prompt) > max_length:
        raise ValueError(
            f"Negative prompt is too long: {len(negative_prompt)} characters"
        )
    return True
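A quick usage sketch of the validator above (strings are illustrative):

validate_prompts("a red fox running through fresh snow", "", MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)  # returns True
validate_prompts("", "blurry, low quality")  # raises ValueError("Positive prompt is empty")
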
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
@@ -102,8 +170,12 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@@ -116,7 +188,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
@@ -126,6 +198,123 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
    """
    Returns a new VideoInput object trimmed from the beginning to the specified duration,
    using av to avoid loading entire video into memory.

    Args:
        video: Input video to trim
        duration_sec: Duration in seconds to keep from the beginning

    Returns:
        VideoFromFile object that owns the output buffer
    """
    output_buffer = io.BytesIO()

    input_container = None
    output_container = None

    try:
        # Get the stream source - this avoids loading entire video into memory
        # when the source is already a file path
        input_source = video.get_stream_source()

        # Open containers
        input_container = av.open(input_source, mode="r")
        output_container = av.open(output_buffer, mode="w", format="mp4")

        # Set up output streams for re-encoding
        video_stream = None
        audio_stream = None

        for stream in input_container.streams:
            logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
            if isinstance(stream, av.VideoStream):
                # Create output video stream with same parameters
                video_stream = output_container.add_stream(
                    "h264", rate=stream.average_rate
                )
                video_stream.width = stream.width
                video_stream.height = stream.height
                video_stream.pix_fmt = "yuv420p"
                logging.info(
                    "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
                )
            elif isinstance(stream, av.AudioStream):
                # Create output audio stream with same parameters
                audio_stream = output_container.add_stream(
                    "aac", rate=stream.sample_rate
                )
                audio_stream.sample_rate = stream.sample_rate
                audio_stream.layout = stream.layout
                logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)

        # Calculate target frame count that's divisible by 16
        fps = input_container.streams.video[0].average_rate
        estimated_frames = int(duration_sec * fps)
        target_frames = (
            estimated_frames // 16
        ) * 16  # Round down to nearest multiple of 16

        if target_frames == 0:
            raise ValueError("Video too short: need at least 16 frames for Moonvalley")

        frame_count = 0
        audio_frame_count = 0

        # Decode and re-encode video frames
        if video_stream:
            for frame in input_container.decode(video=0):
                if frame_count >= target_frames:
                    break

                # Re-encode frame
                for packet in video_stream.encode(frame):
                    output_container.mux(packet)
                frame_count += 1

            # Flush encoder
            for packet in video_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)

        # Decode and re-encode audio frames
        if audio_stream:
            input_container.seek(0)  # Reset to beginning for audio
            for frame in input_container.decode(audio=0):
                if frame.time >= duration_sec:
                    break

                # Re-encode frame
                for packet in audio_stream.encode(frame):
                    output_container.mux(packet)
                audio_frame_count += 1

            # Flush encoder
            for packet in audio_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s audio frames", audio_frame_count)

        # Close containers
        output_container.close()
        input_container.close()

        # Return as VideoFromFile using the buffer
        output_buffer.seek(0)
        return InputImpl.VideoFromFile(output_buffer)

    except Exception as e:
        # Clean up on error
        if input_container is not None:
            input_container.close()
        if output_container is not None:
            output_container.close()
        raise RuntimeError(f"Failed to trim video: {str(e)}") from e

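The frame-count rounding used above, shown in isolation with example numbers:

fps = 24
duration_sec = 5.3
estimated_frames = int(duration_sec * fps)       # 127
target_frames = (estimated_frames // 16) * 16    # 112, the largest multiple of 16 that fits
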
def parse_width_height_from_res(resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
@@ -149,14 +338,19 @@ def parse_control_parameter(value):
|
||||
return control_map.get(value, control_map["Motion Transfer"])
|
||||
|
||||
|
||||
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
status_extractor=lambda r: (r.status if r and r.status else None),
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=240,
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -250,10 +444,14 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@@ -266,17 +464,33 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
),
|
||||
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
@@ -368,10 +582,15 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
||||
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
@@ -386,20 +605,35 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
guidance_scale=prompt_adherence,
|
||||
)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyVideoToVideoRequest(
|
||||
control_type=parse_control_parameter(control_type),
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
),
|
||||
control = parse_control_parameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
|
||||
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
@@ -486,10 +720,14 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@@ -499,16 +737,30 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
|
||||
init_op = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await init_op.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
|
||||
class MoonvalleyExtension(ComfyExtension):
|
||||
|
||||
@@ -43,11 +43,13 @@ from comfy_api_nodes.apis.client import (
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_and_cast_response,
|
||||
validate_string,
|
||||
tensor_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api_nodes.util import downscale_image_tensor, validate_string, tensor_to_base64_string
|
||||
|
||||
|
||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||
|
||||
@@ -14,6 +14,11 @@ import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.apis import pika_defs
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -22,7 +27,6 @@ from comfy_api_nodes.apis.client import (
|
||||
PollingOperation,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@@ -24,7 +24,10 @@ from comfy_api_nodes.apis.client import (
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import validate_string, tensor_to_bytesio
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO

@@ -47,6 +50,7 @@ def get_video_url_from_response(

async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
@@ -55,14 +59,16 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files={"image": tensor_to_bytesio(image)},
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = await operation.execute()

if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)

return response_upload.Resp.img_id

@@ -89,6 +95,7 @@ class PixverseTemplateNode(IO.ComfyNode):
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return IO.NodeOutput(template_id)

@@ -24,10 +24,12 @@ from comfy_api_nodes.apis.client import (
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
download_url_to_bytesio,
tensor_to_bytesio,
resize_mask_to_image,
validate_string,
)
from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor
from server import PromptServer

import torch

@@ -11,7 +11,7 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, Any
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
@@ -21,6 +21,7 @@ from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||
RunwayTaskStatusEnum as TaskStatus,
|
||||
RunwayModelEnum as Model,
|
||||
RunwayDurationEnum as Duration,
|
||||
RunwayAspectRatioEnum as AspectRatio,
|
||||
@@ -32,20 +33,23 @@ from comfy_api_nodes.apis import (
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@@ -87,6 +91,31 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
],
|
||||
failed_statuses=[
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: response.status.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
).execute()
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
@@ -103,32 +132,42 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
data=request,
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
@@ -145,9 +184,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -202,16 +241,20 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -219,9 +262,15 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -235,9 +284,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -292,16 +341,20 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -309,9 +362,15 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
@@ -326,12 +385,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -393,19 +452,23 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -414,11 +477,17 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
@@ -433,7 +502,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -471,34 +540,49 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Prepare reference images if provided
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
|
||||
response_model=RunwayTextToImageResponse,
|
||||
data=RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
ratio=ratio,
|
||||
referenceImages=reference_images,
|
||||
),
|
||||
request = RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
ratio=ratio,
|
||||
referenceImages=reference_images,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_IMAGE,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayTextToImageRequest,
|
||||
response_model=RunwayTextToImageResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
# Poll for completion
|
||||
final_response = await get_response(
|
||||
cls,
|
||||
initial_response.id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
@@ -517,6 +601,5 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayTextToImageNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> RunwayExtension:
|
||||
return RunwayExtension()
|
||||
|
||||
@@ -1,20 +1,23 @@
from typing import Optional
from typing_extensions import override

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images

from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_bytesio,
)


class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
@@ -77,7 +80,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
@@ -108,34 +111,55 @@ class OpenAIVideoSora2(IO.ComfyNode):
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
data=Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
),
request=payload,
files=files_input,
response_model=Sora2GenerationResponse,
auth_kwargs=auth,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error["message"])
raise Exception(initial_response.error.message)

model_time_multiplier = 1 if model == "sora-2" else 2
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
response_model=Sora2GenerationResponse,
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
)
await poll_operation.execute()
return IO.NodeOutput(
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
)


@@ -27,14 +27,14 @@ from comfy_api_nodes.apis.client import (
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
audio_bytes_to_audio_input,
audio_input_to_mp3,
)
from comfy_api_nodes.util.validation_utils import validate_audio_duration

import torch
import base64

File diff suppressed because it is too large
@@ -1,21 +1,28 @@
import logging
import base64
import aiohttp
import torch
from io import BytesIO

from typing import Optional
from typing_extensions import override

from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
from comfy_api_nodes.apis import (
VeoGenVidRequest,
VeoGenVidResponse,
VeoGenVidPollRequest,
VeoGenVidPollResponse,
)
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
HttpMethod,
SynchronousOperation,
PollingOperation,
)

from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
tensor_to_base64_string,
)

@@ -28,6 +35,28 @@ MODELS_MAP = {
|
||||
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||
}
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
else:
|
||||
return None
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return str(video.gcsUri)
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
@@ -140,13 +169,18 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
|
||||
instance = {"prompt": prompt}
|
||||
instance = {
|
||||
"prompt": prompt
|
||||
}
|
||||
|
||||
# Add image if provided
|
||||
if image is not None:
|
||||
image_base64 = tensor_to_base64_string(image)
|
||||
image_base64 = convert_image_to_base64(image)
|
||||
if image_base64:
|
||||
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||
instance["image"] = {
|
||||
"bytesBase64Encoded": image_base64,
|
||||
"mimeType": "image/png"
|
||||
}
|
||||
|
||||
instances.append(instance)
|
||||
|
||||
@@ -164,77 +198,119 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if model.find("veo-2.0") == -1:
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters,
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidRequest,
|
||||
response_model=VeoGenVidResponse
|
||||
),
|
||||
request=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info("Veo generation started with operation name: %s", operation_name)
|
||||
|
||||
# Define status extractor function
|
||||
def status_extractor(response):
|
||||
# Only return "completed" if the operation is done, regardless of success or failure
|
||||
# We'll check for errors after polling completes
|
||||
return "completed" if response.done else "pending"
|
||||
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=status_extractor,
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
# Define progress extractor function
|
||||
def progress_extractor(response):
|
||||
# Could be enhanced if the API provides progress information
|
||||
return None
|
||||
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidPollRequest,
|
||||
response_model=VeoGenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = await poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
if hasattr(poll_response, 'error') and poll_response.error:
|
||||
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Check for RAI filtered content
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredCount")
|
||||
and poll_response.response.raiMediaFilteredCount > 0
|
||||
):
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
||||
poll_response.response.raiMediaFilteredCount > 0):
|
||||
|
||||
# Extract reason message if available
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredReasons")
|
||||
and poll_response.response.raiMediaFilteredReasons
|
||||
):
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
||||
poll_response.response.raiMediaFilteredReasons):
|
||||
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
else:
|
||||
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
||||
# Decode base64 string to bytes
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(video.gcsUri) as video_response:
|
||||
video_data = await video_response.content.read()
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
if not video_data:
|
||||
raise Exception("No video data was returned")
|
||||
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return IO.NodeOutput(VideoFromFile(video_io))
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@@ -317,12 +393,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[
|
||||
"veo-3.1-generate",
|
||||
"veo-3.1-fast-generate",
|
||||
"veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001",
|
||||
],
|
||||
options=list(MODELS_MAP.keys()),
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
@@ -354,6 +425,5 @@ class VeoExtension(ComfyExtension):
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
@@ -1,23 +1,27 @@
import logging
from enum import Enum
from typing import Literal, Optional, TypeVar
from typing import Any, Callable, Optional, Literal, TypeVar
from typing_extensions import override

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
from comfy_api.latest import ComfyExtension, IO
from comfy_api_nodes.util.validation_utils import (
validate_aspect_ratio_closeness,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi


VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
@@ -27,9 +31,8 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = "viduq1"
|
||||
vidu_q1 = 'viduq1'
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -60,9 +63,17 @@ class TaskCreationRequest(BaseModel):
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
created = "created"
|
||||
queueing = "queueing"
|
||||
processing = "processing"
|
||||
success = "success"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
@@ -74,11 +85,32 @@ class TaskResult(BaseModel):
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[TaskStatus.success.value],
|
||||
failed_statuses=[TaskStatus.failed.value],
|
||||
status_extractor=lambda response: response.state.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
@@ -95,27 +127,37 @@ def get_video_from_response(response) -> TaskResult:
|
||||
|
||||
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
node_id: str,
|
||||
) -> R:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=payload,
|
||||
)
|
||||
if response.state == "failed":
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=vidu_endpoint,
|
||||
method=HttpMethod.POST,
|
||||
request_model=TaskCreationRequest,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
if response.state == TaskStatus.failed:
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state.value,
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -216,7 +258,11 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@@ -316,13 +362,17 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@@ -434,13 +484,17 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@@ -542,11 +596,15 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = [
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@@ -560,6 +618,5 @@ class ViduExtension(ComfyExtension):
|
||||
ViduStartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ViduExtension:
|
||||
return ViduExtension()
|
||||
|
||||
@@ -1,24 +1,28 @@
import re
from typing import Optional
from typing import Optional, Type, Union
from typing_extensions import override

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.util import (
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.client import (
ApiEndpoint,
audio_to_base64_string,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
R,
T,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration

from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
validate_audio_duration,
audio_to_base64_string,
)


class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
@@ -142,7 +146,53 @@ class VideoTaskStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
|
||||
|
||||
|
||||
async def process_task(
|
||||
auth_kwargs: dict[str, str],
|
||||
url: str,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
payload: Union[
|
||||
Text2ImageTaskCreationRequest,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
],
|
||||
node_id: str,
|
||||
estimated_duration: int,
|
||||
poll_interval: int,
|
||||
) -> Type[R]:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=url,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=response_model,
|
||||
),
|
||||
completed_statuses=["SUCCEEDED"],
|
||||
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
poll_interval=poll_interval,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
|
||||
class WanTextToImageApi(IO.ComfyNode):
|
||||
@@ -209,7 +259,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -236,28 +286,26 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
payload = Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=9,
|
||||
poll_interval=3,
|
||||
)
|
||||
@@ -272,7 +320,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@@ -328,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -360,30 +408,28 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||
images = []
|
||||
for i in image:
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=42,
|
||||
poll_interval=4,
|
||||
poll_interval=3,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
@@ -477,7 +523,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -511,31 +557,28 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
payload = Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
response_model=VideoTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
@@ -624,7 +667,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -656,37 +699,35 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
payload = Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
response_model=VideoTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
|
||||
@@ -1,91 +0,0 @@
from ._helpers import get_fs_object_size
from .client import (
    ApiEndpoint,
    poll_op,
    poll_op_raw,
    sync_op,
    sync_op_raw,
)
from .conversions import (
    audio_bytes_to_audio_input,
    audio_input_to_mp3,
    audio_to_base64_string,
    bytesio_to_image_tensor,
    downscale_image_tensor,
    image_tensor_pair_to_batch,
    pil_to_bytesio,
    tensor_to_base64_string,
    tensor_to_bytesio,
    tensor_to_pil,
    trim_video,
    video_to_base64_string,
)
from .download_helpers import (
    download_url_as_bytesio,
    download_url_to_bytesio,
    download_url_to_image_tensor,
    download_url_to_video_output,
)
from .upload_helpers import (
    upload_audio_to_comfyapi,
    upload_file_to_comfyapi,
    upload_images_to_comfyapi,
    upload_video_to_comfyapi,
)
from .validation_utils import (
    get_number_of_images,
    validate_aspect_ratio_closeness,
    validate_audio_duration,
    validate_container_format_is_mp4,
    validate_image_aspect_ratio,
    validate_image_aspect_ratio_range,
    validate_image_dimensions,
    validate_string,
    validate_video_dimensions,
    validate_video_duration,
)

__all__ = [
    # API client
    "ApiEndpoint",
    "poll_op",
    "poll_op_raw",
    "sync_op",
    "sync_op_raw",
    # Upload helpers
    "upload_audio_to_comfyapi",
    "upload_file_to_comfyapi",
    "upload_images_to_comfyapi",
    "upload_video_to_comfyapi",
    # Download helpers
    "download_url_as_bytesio",
    "download_url_to_bytesio",
    "download_url_to_image_tensor",
    "download_url_to_video_output",
    # Conversions
    "audio_bytes_to_audio_input",
    "audio_input_to_mp3",
    "audio_to_base64_string",
    "bytesio_to_image_tensor",
    "downscale_image_tensor",
    "image_tensor_pair_to_batch",
    "pil_to_bytesio",
    "tensor_to_base64_string",
    "tensor_to_bytesio",
    "tensor_to_pil",
    "trim_video",
    "video_to_base64_string",
    # Validation utilities
    "get_number_of_images",
    "validate_aspect_ratio_closeness",
    "validate_audio_duration",
    "validate_container_format_is_mp4",
    "validate_image_aspect_ratio",
    "validate_image_aspect_ratio_range",
    "validate_image_dimensions",
    "validate_string",
    "validate_video_dimensions",
    "validate_video_duration",
    # Misc functions
    "get_fs_object_size",
]

@@ -1,71 +0,0 @@
import asyncio
import contextlib
import os
import time
from io import BytesIO
from typing import Callable, Optional, Union

from comfy.cli_args import args
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO

from .common_exceptions import ProcessingInterrupted


def is_processing_interrupted() -> bool:
    """Return True if user/runtime requested interruption."""
    return processing_interrupted()


def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
    return node_cls.hidden.unique_id


def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
    if node_cls.hidden.auth_token_comfy_org:
        return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
    if node_cls.hidden.api_key_comfy_org:
        return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
    return {}


def default_base_url() -> str:
    return getattr(args, "comfy_api_base", "https://api.comfy.org")


async def sleep_with_interrupt(
    seconds: float,
    node_cls: Optional[type[IO.ComfyNode]],
    label: Optional[str] = None,
    start_ts: Optional[float] = None,
    estimated_total: Optional[int] = None,
    *,
    display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
):
    """
    Sleep in 1s slices while:
    - Checking for interruption (raises ProcessingInterrupted).
    - Optionally emitting time progress via display_callback (if provided).
    """
    end = time.monotonic() + seconds
    while True:
        if is_processing_interrupted():
            raise ProcessingInterrupted("Task cancelled")
        now = time.monotonic()
        if start_ts is not None and label and display_callback:
            with contextlib.suppress(Exception):
                display_callback(node_cls, label, int(now - start_ts), estimated_total)
        if now >= end:
            break
        await asyncio.sleep(min(1.0, end - now))


def mimetype_to_extension(mime_type: str) -> str:
    """Converts a MIME type to a file extension."""
    return mime_type.split("/")[-1].lower()


def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
    if isinstance(path_or_object, str):
        return os.path.getsize(path_or_object)
    return len(path_or_object.getvalue())
@@ -1,936 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from comfy_api_nodes.apis import request_logger
|
||||
from server import PromptServer
|
||||
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
)
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
class ApiEndpoint:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||
*,
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
):
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.query_params = query_params or {}
|
||||
self.headers = headers or {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RequestConfig:
|
||||
node_cls: type[IO.ComfyNode]
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
data: Optional[dict[str, Any]]
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||
multipart_parser: Optional[Callable]
|
||||
max_retries: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
monitor_progress: bool = True
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PollUIState:
|
||||
started: float
|
||||
status_label: str = "Queued"
|
||||
is_queued: bool = True
|
||||
price: Optional[float] = None
|
||||
estimated_duration: Optional[int] = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
endpoint,
|
||||
data=data,
|
||||
files=files,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
estimated_duration=estimated_duration,
|
||||
as_binary=False,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
poll_endpoint=poll_endpoint,
|
||||
status_extractor=_wrap_model_extractor(response_model, status_extractor),
|
||||
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
completed_statuses=completed_statuses,
|
||||
failed_statuses=failed_statuses,
|
||||
queued_statuses=queued_statuses,
|
||||
data=data,
|
||||
poll_interval=poll_interval,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
timeout_per_poll=timeout_per_poll,
|
||||
max_retries_per_poll=max_retries_per_poll,
|
||||
retry_delay_per_poll=retry_delay_per_poll,
|
||||
retry_backoff_per_poll=retry_backoff_per_poll,
|
||||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
as_binary: bool = False,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> Union[dict[str, Any], bytes]:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
for k, v in list(data.items()):
|
||||
if isinstance(v, Enum):
|
||||
data[k] = v.value
|
||||
cfg = _RequestConfig(
|
||||
node_cls=cls,
|
||||
endpoint=endpoint,
|
||||
timeout=timeout,
|
||||
content_type=content_type,
|
||||
data=data,
|
||||
files=files,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
monitor_progress=monitor_progress,
|
||||
estimated_total=estimated_duration,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
|
||||
async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
|
||||
|
||||
Uses default complete, failed and queued states assumption.
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||
started = time.monotonic()
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||
last_progress: Optional[int] = None
|
||||
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
|
||||
async def _ticker():
|
||||
"""Emit a UI update every second while polling is in progress."""
|
||||
try:
|
||||
while not stop_ticker.is_set():
|
||||
if is_processing_interrupted():
|
||||
break
|
||||
now = time.monotonic()
|
||||
proc_elapsed = state.base_processing_elapsed + (
|
||||
(now - state.active_since) if state.active_since is not None else 0.0
|
||||
)
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=state.status_label,
|
||||
elapsed_seconds=int(now - state.started),
|
||||
estimated_total=state.estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
logging.debug("Polling ticker exited: %s", exc)
|
||||
|
||||
ticker_task = asyncio.create_task(_ticker())
|
||||
try:
|
||||
while consumed_attempts < max_poll_attempts:
|
||||
try:
|
||||
resp_json = await sync_op_raw(
|
||||
cls,
|
||||
poll_endpoint,
|
||||
data=data,
|
||||
timeout=timeout_per_poll,
|
||||
max_retries=max_retries_per_poll,
|
||||
retry_delay=retry_delay_per_poll,
|
||||
retry_backoff=retry_backoff_per_poll,
|
||||
wait_label="Checking",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
if not isinstance(resp_json, dict):
|
||||
raise Exception("Polling endpoint returned non-JSON response.")
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
status = _normalize_status_value(status_extractor(resp_json))
|
||||
except Exception as e:
|
||||
logging.error("Status extraction failed: %s", e)
|
||||
status = None
|
||||
|
||||
if price_extractor:
|
||||
new_price = price_extractor(resp_json)
|
||||
if new_price is not None:
|
||||
state.price = new_price
|
||||
|
||||
if progress_extractor:
|
||||
new_progress = progress_extractor(resp_json)
|
||||
if new_progress is not None and last_progress != new_progress:
|
||||
progress_bar.update_absolute(new_progress, total=100)
|
||||
last_progress = new_progress
|
||||
|
||||
now_ts = time.monotonic()
|
||||
is_queued = status in queued_states
|
||||
|
||||
if is_queued:
|
||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
else:
|
||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||
state.active_since = now_ts
|
||||
|
||||
state.is_queued = is_queued
|
||||
state.status_label = status or ("Queued" if is_queued else "Processing")
|
||||
if status in completed_states:
|
||||
if state.active_since is not None:
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
if progress_bar and last_progress != 100:
|
||||
progress_bar.update_absolute(100, total=100)
|
||||
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=status if status else "Completed",
|
||||
elapsed_seconds=int(now_ts - started),
|
||||
estimated_total=estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
)
|
||||
return resp_json
|
||||
|
||||
if status in failed_states:
|
||||
msg = f"Task failed: {json.dumps(resp_json)}"
|
||||
logging.error(msg)
|
||||
raise Exception(msg)
|
||||
|
||||
try:
|
||||
await sleep_with_interrupt(poll_interval, cls, None, None, None)
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
if not is_queued:
|
||||
consumed_attempts += 1
|
||||
|
||||
raise Exception(
|
||||
f"Polling timed out after {max_poll_attempts} non-queued attempts "
|
||||
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
|
||||
)
|
||||
except ProcessingInterrupted:
|
||||
raise
|
||||
except (LocalNetworkError, ApiServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||
finally:
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
text: Optional[str],
|
||||
*,
|
||||
status: Optional[Union[str, int]] = None,
|
||||
price: Optional[float] = None,
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
display_lines.append(f"Price: ${float(price):,.4f}")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
status: Optional[Union[str, int]],
|
||||
elapsed_seconds: int,
|
||||
estimated_total: Optional[int] = None,
|
||||
*,
|
||||
price: Optional[float] = None,
|
||||
is_queued: Optional[bool] = None,
|
||||
processing_elapsed_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
remaining = max(0, int(estimated_total) - int(pe))
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get("https://www.google.com") as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
parsed = urlparse(default_base_url())
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get(health_url) as resp:
|
||||
results["api_accessible"] = resp.status < 500
|
||||
return results
|
||||
|
||||
|
||||
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
"""Normalize (filename, value, content_type)."""
|
||||
if len(t) == 2:
|
||||
return t[0], t[1], "application/octet-stream"
|
||||
if len(t) == 3:
|
||||
return t[0], t[1], t[2]
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
if v is not None:
|
||||
params[k] = v
|
||||
return params
|
||||
|
||||
|
||||
def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 401:
|
||||
return "Unauthorized: Please login first to use this node."
|
||||
if status == 402:
|
||||
return "Payment Required: Please add credits to your account to use this node."
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
if isinstance(err, dict):
|
||||
msg = err.get("message")
|
||||
typ = err.get("type")
|
||||
if msg and typ:
|
||||
return f"API Error: {msg} (Type: {typ})"
|
||||
if msg:
|
||||
return f"API Error: {msg}"
|
||||
return f"API Error: {json.dumps(body)}"
|
||||
else:
|
||||
txt = str(body)
|
||||
if len(txt) <= 200:
|
||||
return f"API Error (raw): {txt}"
|
||||
return f"API Error (status {status})"
|
||||
except Exception:
|
||||
return f"HTTP {status}: Unknown error"
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||
slug = path.strip("/").replace("/", "_") or "op"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _snapshot_request_body_for_logging(
|
||||
content_type: str,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]],
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||
) -> Optional[Union[dict[str, Any], str]]:
|
||||
if method.upper() == "GET":
|
||||
return None
|
||||
if content_type == "multipart/form-data":
|
||||
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||||
file_fields: list[dict[str, str]] = []
|
||||
if files:
|
||||
file_iter = files if isinstance(files, list) else list(files.items())
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename = file_obj[0]
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
return data or {}
|
||||
return data or {}
|
||||
|
||||
|
||||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||
url = cfg.endpoint.path
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
|
||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||
"""Every second: update elapsed time and signal interruption."""
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return # normal shutdown
|
||||
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
|
||||
payload_headers = {"Accept": "*/*"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||
if method == "GET":
|
||||
payload_headers.pop("Content-Type", None)
|
||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||
try:
|
||||
if cfg.monitor_progress:
|
||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
if cfg.content_type == "multipart/form-data" and method != "GET":
|
||||
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
|
||||
payload_headers.pop("Content-Type", None)
|
||||
if cfg.multipart_parser and cfg.data:
|
||||
form = cfg.multipart_parser(cfg.data)
|
||||
if not isinstance(form, aiohttp.FormData):
|
||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if cfg.data:
|
||||
for k, v in cfg.data.items():
|
||||
if v is None:
|
||||
continue
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
if cfg.files:
|
||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename, file_value, content_type = _unpack_tuple(file_obj)
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_value = file_obj
|
||||
content_type = "application/octet-stream"
|
||||
# Attempt to rewind BytesIO for retries
|
||||
if isinstance(file_value, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file_value.seek(0)
|
||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||
payload_kw["data"] = form
|
||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
payload_kw["data"] = cfg.data or {}
|
||||
elif method != "GET":
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
|
||||
# Race: request vs. monitor (interruption)
|
||||
tasks = {req_task}
|
||||
if monitor_task:
|
||||
tasks.add(monitor_task)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task and monitor_task in done:
|
||||
# Interrupted – cancel the request and abort
|
||||
if req_task in pending:
|
||||
req_task.cancel()
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
# Otherwise, request finished
|
||||
resp = await req_task
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
buff = bytearray()
|
||||
last_tick = time.monotonic()
|
||||
async for chunk in resp.content.iter_chunked(64 * 1024):
|
||||
buff.extend(chunk)
|
||||
now = time.monotonic()
|
||||
if now - last_tick >= 1.0:
|
||||
last_tick = now
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
payload = await resp.json()
|
||||
response_content_to_log: Any = payload
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
text = await resp.text()
|
||||
try:
|
||||
payload = json.loads(text) if text else {}
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
) from e
|
||||
finally:
|
||||
stop_event.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||
_display_time_progress(
|
||||
cfg.node_cls,
|
||||
status=cfg.final_label_on_success,
|
||||
elapsed_seconds=(
|
||||
final_elapsed_seconds
|
||||
if final_elapsed_seconds is not None
|
||||
else int(time.monotonic() - start_time)
|
||||
),
|
||||
estimated_total=cfg.estimated_total,
|
||||
price=None,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=final_elapsed_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
try:
|
||||
return response_model.model_validate(payload)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Response validation failed for %s: %s",
|
||||
getattr(response_model, "__name__", response_model),
|
||||
e,
|
||||
)
|
||||
raise Exception(
|
||||
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _wrap_model_extractor(
|
||||
response_model: Type[M],
|
||||
extractor: Optional[Callable[[M], Any]],
|
||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||
Validates the dict into `response_model` before invoking `extractor`.
|
||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||
the same response for multiple extractors in a single poll attempt.
|
||||
"""
|
||||
if extractor is None:
|
||||
return None
|
||||
_cache: dict[int, M] = {}
|
||||
|
||||
def _wrapped(d: dict[str, Any]) -> Any:
|
||||
try:
|
||||
key = id(d)
|
||||
model = _cache.get(key)
|
||||
if model is None:
|
||||
model = response_model.model_validate(d)
|
||||
_cache[key] = model
|
||||
return extractor(model)
|
||||
except Exception as e:
|
||||
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||||
raise
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||
if not values:
|
||||
return set()
|
||||
out: set[Union[str, int]] = set()
|
||||
for v in values:
|
||||
nv = _normalize_status_value(v)
|
||||
if nv is not None:
|
||||
out.add(nv)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||
if isinstance(val, str):
|
||||
return val.strip().lower()
|
||||
return val
|
||||
@@ -1,14 +0,0 @@
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
|
||||
|
||||
class ProcessingInterrupted(Exception):
|
||||
"""Operation was interrupted by user/runtime via processing_interrupted()."""
|
||||
@@ -1,432 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: Input.Video,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2**15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2**31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = _f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
@@ -1,261 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api_nodes.apis import request_logger

from ._helpers import (
    default_base_url,
    get_auth_header,
    is_processing_interrupted,
    sleep_with_interrupt,
)
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor

_RETRY_STATUS = {408, 429, 500, 502, 503, 504}


async def download_url_to_bytesio(
    url: str,
    dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
    *,
    timeout: Optional[float] = None,
    max_retries: int = 5,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    cls: type[COMFY_IO.ComfyNode] = None,
) -> None:
    """Stream-download a URL to `dest`.

    `dest` must be one of:
    - a BytesIO (rewound to 0 after write),
    - a file-like object opened in binary write mode (must implement .write()),
    - a filesystem path (str | pathlib.Path), which will be opened with 'wb'.

    If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
    to an absolute URL and authentication headers can be applied.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
    """
    if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
        raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")

    attempt = 0
    delay = retry_delay
    headers: dict[str, str] = {}

    parsed_url = urlparse(url)
    if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
        if cls is None:
            raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
        url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
        headers = get_auth_header(cls)

    while True:
        attempt += 1
        op_id = _generate_operation_id("GET", url, attempt)
        timeout_cfg = aiohttp.ClientTimeout(total=timeout)

        is_path_sink = isinstance(dest, (str, Path))
        fhandle = None
        session: Optional[aiohttp.ClientSession] = None
        stop_evt: Optional[asyncio.Event] = None
        monitor_task: Optional[asyncio.Task] = None
        req_task: Optional[asyncio.Task] = None

        try:
            with contextlib.suppress(Exception):
                request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)

            session = aiohttp.ClientSession(timeout=timeout_cfg)
            stop_evt = asyncio.Event()

            async def _monitor():
                try:
                    while not stop_evt.is_set():
                        if is_processing_interrupted():
                            return
                        await asyncio.sleep(1.0)
                except asyncio.CancelledError:
                    return

            monitor_task = asyncio.create_task(_monitor())

            req_task = asyncio.create_task(session.get(url, headers=headers))
            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task in done and req_task in pending:
                req_task.cancel()
                with contextlib.suppress(Exception):
                    await req_task
                raise ProcessingInterrupted("Task cancelled")

            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Task cancelled") from None

            async with resp:
                if resp.status >= 400:
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except (ContentTypeError, ValueError):
                            text = await resp.text()
                            body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
                        request_logger.log_request_response(
                            operation_id=op_id,
                            request_method="GET",
                            request_url=url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=f"HTTP {resp.status}",
                        )

                    if resp.status in _RETRY_STATUS and attempt <= max_retries:
                        await sleep_with_interrupt(delay, cls, None, None, None)
                        delay *= retry_backoff
                        continue
                    raise Exception(f"Failed to download (HTTP {resp.status}).")

                if is_path_sink:
                    p = Path(str(dest))
                    with contextlib.suppress(Exception):
                        p.parent.mkdir(parents=True, exist_ok=True)
                    fhandle = open(p, "wb")
                    sink = fhandle
                else:
                    sink = dest  # BytesIO or file-like

                written = 0
                while True:
                    try:
                        chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
                    except asyncio.TimeoutError:
                        chunk = b""
                    except asyncio.CancelledError:
                        raise ProcessingInterrupted("Task cancelled") from None

                    if is_processing_interrupted():
                        raise ProcessingInterrupted("Task cancelled")

                    if not chunk:
                        if resp.content.at_eof():
                            break
                        continue

                    sink.write(chunk)
                    written += len(chunk)

                if isinstance(dest, BytesIO):
                    with contextlib.suppress(Exception):
                        dest.seek(0)

                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=op_id,
                        request_method="GET",
                        request_url=url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=f"[streamed {written} bytes to dest]",
                    )
                return
        except asyncio.CancelledError:
            raise ProcessingInterrupted("Task cancelled") from None
        except (ClientError, OSError) as e:
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=op_id,
                        request_method="GET",
                        request_url=url,
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(delay, cls, None, None, None)
                delay *= retry_backoff
                continue

            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                raise LocalNetworkError(
                    "Unable to connect to the network. Please check your internet connection and try again."
                ) from e
            raise ApiServerError("The remote service appears unreachable at this time.") from e
        finally:
            if stop_evt is not None:
                stop_evt.set()
            if monitor_task:
                monitor_task.cancel()
                with contextlib.suppress(Exception):
                    await monitor_task
            if req_task and not req_task.done():
                req_task.cancel()
                with contextlib.suppress(Exception):
                    await req_task
            if session:
                with contextlib.suppress(Exception):
                    await session.close()
            if fhandle:
                with contextlib.suppress(Exception):
                    fhandle.flush()
                    fhandle.close()


async def download_url_to_image_tensor(
    url: str,
    *,
    timeout: float = None,
    cls: type[COMFY_IO.ComfyNode] = None,
) -> torch.Tensor:
    """Downloads an image from a URL and returns a [B, H, W, C] tensor."""
    result = BytesIO()
    await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
    return bytesio_to_image_tensor(result)


async def download_url_to_video_output(
    video_url: str,
    *,
    timeout: float = None,
    cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile:
    """Downloads a video from a URL and returns a `VIDEO` output."""
    result = BytesIO()
    await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls)
    return VideoFromFile(result)


async def download_url_as_bytesio(
    url: str,
    *,
    timeout: float = None,
    cls: type[COMFY_IO.ComfyNode] = None,
) -> BytesIO:
    """Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
    result = BytesIO()
    await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
    return result


def _generate_operation_id(method: str, url: str, attempt: int) -> str:
    try:
        parsed = urlparse(url)
        slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
    except Exception:
        slug = "download"
    return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
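For orientation, a minimal usage sketch of these download helpers inside an async node body might look like the following; the URLs, the proxy path, and the output filename are made-up placeholders, and error handling is omitted.

# Hedged sketch, assuming it runs inside an async ComfyNode where `cls` carries the
# auth context; every URL/path below is a hypothetical placeholder.
async def _fetch_outputs(cls):
    # Image result -> [B, H, W, C] tensor (relative "cloud" path, so cls is required)
    image = await download_url_to_image_tensor("/proxy/jobs/1234/result.png", timeout=60.0, cls=cls)

    # Arbitrary bytes -> rewound BytesIO (e.g. to parse or re-upload later)
    blob = await download_url_as_bytesio("https://example.com/file.bin", timeout=60.0)

    # Stream a large file straight to disk instead of holding it in memory
    await download_url_to_bytesio("https://example.com/big.mp4", "output/big.mp4", timeout=None)
    return image, blob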
@@ -1,338 +0,0 @@
import asyncio
import contextlib
import logging
import time
import uuid
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
import torch
from pydantic import BaseModel, Field

from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger

from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
    ApiEndpoint,
    _diagnose_connectivity,
    _display_time_progress,
    sync_op,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import (
    audio_ndarray_to_bytesio,
    audio_tensor_to_contiguous_ndarray,
    tensor_to_bytesio,
)


class UploadRequest(BaseModel):
    file_name: str = Field(..., description="Filename to upload")
    content_type: Optional[str] = Field(
        None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )


class UploadResponse(BaseModel):
    download_url: str = Field(..., description="URL to GET uploaded file")
    upload_url: str = Field(..., description="URL to PUT file to upload")


async def upload_images_to_comfyapi(
    cls: type[IO.ComfyNode],
    image: torch.Tensor,
    *,
    max_images: int = 8,
    mime_type: Optional[str] = None,
    wait_label: Optional[str] = "Uploading",
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
    To upload multiple images, stack them in the batch dimension first.
    """
    # if batch, try to upload each file if max_images is greater than 0
    download_urls: list[str] = []
    is_batch = len(image.shape) > 3
    batch_len = image.shape[0] if is_batch else 1

    for idx in range(min(batch_len, max_images)):
        tensor = image[idx] if is_batch else image
        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
        url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
        download_urls.append(url)
    return download_urls


async def upload_audio_to_comfyapi(
    cls: type[IO.ComfyNode],
    audio: Input.Audio,
    *,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
    filename: str = "uploaded_audio.mp4",
) -> str:
    """
    Uploads a single audio input to ComfyUI API and returns its download URL.
    Encodes the raw waveform into the specified format before uploading.
    """
    sample_rate: int = audio["sample_rate"]
    waveform: torch.Tensor = audio["waveform"]
    audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
    audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
    return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)


async def upload_video_to_comfyapi(
    cls: type[IO.ComfyNode],
    video: Input.Video,
    *,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
) -> str:
    """
    Uploads a single video to ComfyUI API and returns its download URL.
    Uses the specified container and codec for saving the video before upload.
    """
    if max_duration is not None:
        try:
            actual_duration = video.get_duration()
            if actual_duration > max_duration:
                raise ValueError(
                    f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
                )
        except Exception as e:
            logging.error("Error getting video duration: %s", str(e))
            raise ValueError(f"Could not verify video duration from source: {e}") from e

    upload_mime_type = f"video/{container.value.lower()}"
    filename = f"uploaded_video.{container.value.lower()}"

    # Convert VideoInput to BytesIO using specified container/codec
    video_bytes_io = BytesIO()
    video.save_to(video_bytes_io, format=container, codec=codec)
    video_bytes_io.seek(0)

    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)


async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: Optional[str],
    wait_label: Optional[str] = "Uploading",
) -> str:
    """Uploads a single file to ComfyUI API and returns its download URL."""
    if upload_mime_type is None:
        request_object = UploadRequest(file_name=filename)
    else:
        request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
    create_resp = await sync_op(
        cls,
        endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
        data=request_object,
        response_model=UploadResponse,
        final_label_on_success=None,
        monitor_progress=False,
    )
    await upload_file(
        cls,
        create_resp.upload_url,
        file_bytes_io,
        content_type=upload_mime_type,
        wait_label=wait_label,
    )
    return create_resp.download_url


async def upload_file(
    cls: type[IO.ComfyNode],
    upload_url: str,
    file: Union[BytesIO, str],
    *,
    content_type: Optional[str] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: Optional[str] = None,
) -> None:
    """
    Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.

    Args:
        cls: Node class (provides auth context + UI progress hooks).
        upload_url: Pre-signed PUT URL.
        file: BytesIO or path string.
        content_type: Explicit MIME type. If None, we *suppress* Content-Type.
        max_retries: Maximum retry attempts.
        retry_delay: Initial delay in seconds.
        retry_backoff: Exponential backoff factor.
        wait_label: Progress label shown in Comfy UI.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
    """
    if isinstance(file, BytesIO):
        with contextlib.suppress(Exception):
            file.seek(0)
        data = file.read()
    elif isinstance(file, str):
        with open(file, "rb") as f:
            data = f.read()
    else:
        raise ValueError("file must be a BytesIO or a filesystem path string")

    headers: dict[str, str] = {}
    skip_auto_headers: set[str] = set()
    if content_type:
        headers["Content-Type"] = content_type
    else:
        skip_auto_headers.add("Content-Type")  # Don't let aiohttp add Content-Type, it can break the signed request

    attempt = 0
    delay = retry_delay
    start_ts = time.monotonic()
    op_uuid = uuid.uuid4().hex[:8]
    while True:
        attempt += 1
        operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
        timeout = aiohttp.ClientTimeout(total=None)
        stop_evt = asyncio.Event()

        async def _monitor():
            try:
                while not stop_evt.is_set():
                    if is_processing_interrupted():
                        return
                    if wait_label:
                        _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                return

        monitor_task = asyncio.create_task(_monitor())
        sess: Optional[aiohttp.ClientSession] = None
        try:
            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    request_headers=headers or None,
                    request_params=None,
                    request_data=f"[File data {len(data)} bytes]",
                )
            except Exception as e:
                logging.debug("[DEBUG] upload request logging failed: %s", e)

            sess = aiohttp.ClientSession(timeout=timeout)
            req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
            req_task = asyncio.create_task(req)

            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task in done and req_task in pending:
                req_task.cancel()
                raise ProcessingInterrupted("Upload cancelled")

            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Upload cancelled") from None

            async with resp:
                if resp.status >= 400:
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except Exception:
                            body = await resp.text()
                        msg = f"Upload failed with status {resp.status}"
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=msg,
                        )
                    if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
                        await sleep_with_interrupt(
                            delay,
                            cls,
                            wait_label,
                            start_ts,
                            None,
                            display_callback=_display_time_progress if wait_label else None,
                        )
                        delay *= retry_backoff
                        continue
                    raise Exception(f"Failed to upload (HTTP {resp.status}).")
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content="File uploaded successfully.",
                    )
                except Exception as e:
                    logging.debug("[DEBUG] upload response logging failed: %s", e)
                return
        except asyncio.CancelledError:
            raise ProcessingInterrupted("Task cancelled") from None
        except (aiohttp.ClientError, OSError) as e:
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        request_headers=headers or None,
                        request_data=f"[File data {len(data)} bytes]",
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(
                    delay,
                    cls,
                    wait_label,
                    start_ts,
                    None,
                    display_callback=_display_time_progress if wait_label else None,
                )
                delay *= retry_backoff
                continue

            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                raise LocalNetworkError(
                    "Unable to connect to the network. Please check your internet connection and try again."
                ) from e
            raise ApiServerError("The API service appears unreachable at this time.") from e
        finally:
            stop_evt.set()
            if monitor_task:
                monitor_task.cancel()
                with contextlib.suppress(Exception):
                    await monitor_task
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()


def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
    try:
        parsed = urlparse(url)
        slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
    except Exception:
        slug = "upload"
    return f"{method}_{slug}_{op_uuid}_try{attempt}"
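Both the download and upload helpers above follow the same retry policy: only HTTP 408/429/5xx responses and network-level errors are retried, and the wait grows geometrically. A minimal sketch of that schedule, not taken from the repo:

# Hedged sketch: attempt k is retried after retry_delay * retry_backoff**(k-1) seconds.
def backoff_schedule(max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0) -> list[float]:
    delays = []
    delay = retry_delay
    for _ in range(max_retries):
        delays.append(delay)
        delay *= retry_backoff
    return delays

print(backoff_schedule())  # [1.0, 2.0, 4.0]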
@@ -2,8 +2,6 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
@@ -30,7 +28,9 @@ def validate_image_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
|
||||
raise ValueError(
|
||||
f"Image height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
|
||||
|
||||
@@ -44,9 +44,13 @@ def validate_image_aspect_ratio(
|
||||
aspect_ratio = width / height
|
||||
|
||||
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
|
||||
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
|
||||
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
|
||||
|
||||
def validate_image_aspect_ratio_range(
|
||||
@@ -54,7 +58,7 @@ def validate_image_aspect_ratio_range(
|
||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
a1, b1 = min_ratio
|
||||
a2, b2 = max_ratio
|
||||
@@ -81,7 +85,7 @@ def validate_aspect_ratio_closeness(
|
||||
min_rel: float,
|
||||
max_rel: float,
|
||||
*,
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
) -> None:
|
||||
w1, h1 = get_image_dimensions(start_img)
|
||||
w2, h2 = get_image_dimensions(end_img)
|
||||
@@ -114,7 +118,9 @@ def validate_video_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
|
||||
raise ValueError(
|
||||
f"Video height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
|
||||
|
||||
@@ -132,9 +138,13 @@ def validate_video_duration(
|
||||
|
||||
epsilon = 0.0001
|
||||
if min_duration is not None and min_duration - epsilon > duration:
|
||||
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
|
||||
raise ValueError(
|
||||
f"Video duration must be at least {min_duration}s, got {duration}s"
|
||||
)
|
||||
if max_duration is not None and duration > max_duration + epsilon:
|
||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||
raise ValueError(
|
||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||
)
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
@@ -155,31 +165,3 @@ def validate_audio_duration(
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
|
||||
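The aspect-ratio helpers in the hunk above take the bounds as (width, height) tuples such as (1, 4) and (4, 1). A minimal sketch of the underlying check, under the assumption that the tuples are simply converted to floats and compared exclusively when strict=True:

# Hedged sketch; the repo's helper may differ in details such as error messages.
def ratio_in_range(width: int, height: int, min_ratio=(1, 4), max_ratio=(4, 1), strict: bool = True) -> bool:
    ratio = width / height
    lo = min_ratio[0] / min_ratio[1]   # 0.25 for (1, 4)
    hi = max_ratio[0] / max_ratio[1]   # 4.0 for (4, 1)
    return lo < ratio < hi if strict else lo <= ratio <= hi

print(ratio_in_range(1024, 512))  # True: 2.0 lies inside (0.25, 4.0)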
@@ -265,26 +265,6 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class NullCache:
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
pass
|
||||
|
||||
def all_node_ids(self):
|
||||
return []
|
||||
|
||||
def clean_unused(self):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
return None
|
||||
|
||||
def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
@@ -336,3 +316,157 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(child_id)
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
|
||||
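The eviction rule DependencyAwareCache implements can be stated in one line: a producer's output may be dropped once the producer and every node that consumes it have executed. A self-contained toy sketch (not repo code):

# Hedged sketch of the "all descendants executed" test behind _cleanup_ancestors.
def can_evict(producer: str, descendants: dict, executed: set) -> bool:
    return producer in executed and all(d in executed for d in descendants.get(producer, set()))

descendants = {"load": {"sample", "preview"}}    # 'load' feeds two consumers
executed = {"load", "sample"}
print(can_evict("load", descendants, executed))  # False: 'preview' still pending
executed.add("preview")
print(can_evict("load", descendants, executed))  # True: safe to free 'load'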
@@ -153,9 +153,8 @@ class TopologicalSort:
|
||||
continue
|
||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy):
|
||||
if not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
for link in links:
|
||||
@@ -195,35 +194,10 @@ class ExecutionList(TopologicalSort):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
self.execution_cache = {}
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_output_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
return None
|
||||
return self.execution_cache[to_node_id].get(from_node_id)
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
if node_id in self.execution_cache_listeners:
|
||||
for to_node_id in self.execution_cache_listeners[node_id]:
|
||||
if to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id][node_id] = value
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
self.cache_link(from_node_id, to_node_id)
|
||||
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
@@ -303,8 +277,6 @@ class ExecutionList(TopologicalSort):
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.execution_cache.pop(node_id, None)
|
||||
self.execution_cache_listeners.pop(node_id, None)
|
||||
self.staged_node_id = None
|
||||
|
||||
def get_nodes_in_cycle(self):
|
||||
|
||||
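The hunk above adds a small per-consumer cache to ExecutionList: each strong link snapshots the producer's current output under the consumer's id, listeners let later updates propagate, and a node's entries are dropped when it finishes. A standalone toy version of the same idea, not the repo's implementation:

# Hedged sketch mirroring cache_link / cache_update / get_output_cache.
class TinyExecCache:
    def __init__(self, outputs):
        self.outputs = outputs        # global output cache (node_id -> value)
        self.per_consumer = {}        # to_node -> {from_node: value}
        self.listeners = {}           # from_node -> set(to_node)

    def cache_link(self, from_id, to_id):
        self.per_consumer.setdefault(to_id, {})[from_id] = self.outputs.get(from_id)
        self.listeners.setdefault(from_id, set()).add(to_id)

    def cache_update(self, node_id, value):
        for to_id in self.listeners.get(node_id, ()):
            self.per_consumer[to_id][node_id] = value

    def get_output_cache(self, from_id, to_id):
        return self.per_consumer.get(to_id, {}).get(from_id)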
@@ -1,26 +1,20 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
import nodes
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class SetUnionControlNetType(io.ComfyNode):
|
||||
class SetUnionControlNetType:
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SetUnionControlNetType",
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
||||
],
|
||||
outputs=[
|
||||
io.ControlNet.Output(),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"control_net": ("CONTROL_NET", ),
|
||||
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
|
||||
}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, control_net, type) -> io.NodeOutput:
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
|
||||
FUNCTION = "set_controlnet_type"
|
||||
|
||||
def set_controlnet_type(self, control_net, type):
|
||||
control_net = control_net.copy()
|
||||
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
||||
if type_number >= 0:
|
||||
@@ -28,36 +22,27 @@ class SetUnionControlNetType(io.ComfyNode):
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
return io.NodeOutput(control_net)
|
||||
return (control_net,)
|
||||
|
||||
set_controlnet_type = execute # TODO: remove
|
||||
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Mask.Input("mask"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
|
||||
FUNCTION = "apply_inpaint_controlnet"
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||
extra_concat = []
|
||||
if control_net.concat_mask:
|
||||
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
@@ -65,20 +50,11 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||
extra_concat = [mask]
|
||||
|
||||
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
return io.NodeOutput(result[0], result[1])
|
||||
|
||||
apply_inpaint_controlnet = execute # TODO: remove
|
||||
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
|
||||
|
||||
class ControlNetExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SetUnionControlNetType,
|
||||
ControlNetInpaintingAliMamaApply,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ControlNetExtension:
|
||||
return ControlNetExtension()
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SetUnionControlNetType": SetUnionControlNetType,
|
||||
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||
}
|
||||
|
||||
@@ -244,8 +244,6 @@ class EasyCacheHolder:
|
||||
self.total_steps_skipped += 1
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# slice out only what is relevant to this cond
|
||||
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
@@ -263,8 +261,9 @@ class EasyCacheHolder:
|
||||
slicing.append(slice(None, dim_u))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
batch_slice = batch_slice + slicing
|
||||
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
||||
x = x[slicing]
|
||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
|
||||
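The EasyCache fragment above slices out one cond's rows from the batch and then clips each remaining dimension to the cached tensor's size before adding the cached diff. A hedged, self-contained illustration of that indexing (toy shapes, not repo code):

import torch

x = torch.zeros(4, 16, 64, 64)            # current latent: 2 conds x batch 2
cached = torch.ones(2, 16, 48, 48)        # cached diff with smaller spatial dims
i, batch_offset = 0, x.shape[0] // 2      # rows belonging to cond index 0

slicing = [slice(None) if dim_x == dim_c else slice(dim_c)
           for dim_x, dim_c in zip(x.shape[1:], cached.shape[1:])]
batch_slice = tuple([slice(i * batch_offset, (i + 1) * batch_offset)] + slicing)
x[batch_slice] += cached.to(x.device)     # only the overlapping region is updated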
comfy_extras/nodes_multigpu.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUWorkUnitsNode:
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_WorkUnits"
|
||||
NodeName = "MultiGPU Work Units"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "init_multigpu"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
|
||||
return (model,)
|
||||
|
||||
class MultiGPUOptionsNode:
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_Options"
|
||||
NodeName = "MultiGPU Options"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
|
||||
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GPU_OPTIONS",)
|
||||
FUNCTION = "create_gpu_options"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return (gpu_options,)
|
||||
|
||||
|
||||
node_list = [
|
||||
MultiGPUWorkUnitsNode,
|
||||
MultiGPUOptionsNode
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
for node in node_list:
|
||||
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
||||
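For context, a hedged sketch (not from the repo) of what chaining the two nodes above amounts to, assuming `model` is an already-loaded ModelPatcher:

import comfy.multigpu

opts = comfy.multigpu.GPUOptionsGroup()
opts.add(comfy.multigpu.GPUOptions(device_index=0, relative_speed=1.0))
opts.add(comfy.multigpu.GPUOptions(device_index=1, relative_speed=0.5))  # slower second GPU
model = comfy.multigpu.create_multigpu_deepclones(model, 2, opts, reuse_loaded=True)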
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.67"
|
||||
__version__ = "0.3.65"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import importlib.util
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
from comfy.cli_args import args
|
||||
import subprocess
|
||||
|
||||
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||
@@ -75,9 +75,8 @@ if not args.cuda_malloc:
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
34
execution.py
34
execution.py
@@ -18,7 +18,7 @@ from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
NullCache,
|
||||
DependencyAwareCache,
|
||||
HierarchicalCache,
|
||||
LRUCache,
|
||||
)
|
||||
@@ -91,13 +91,13 @@ class IsChangedCache:
|
||||
class CacheType(Enum):
|
||||
CLASSIC = 0
|
||||
LRU = 1
|
||||
NONE = 2
|
||||
DEPENDENCY_AWARE = 2
|
||||
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, cache_type=None, cache_size=None):
|
||||
if cache_type == CacheType.NONE:
|
||||
self.init_null_cache()
|
||||
if cache_type == CacheType.DEPENDENCY_AWARE:
|
||||
self.init_dependency_aware_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
if cache_size is None:
|
||||
@@ -120,12 +120,11 @@ class CacheSet:
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
self.outputs = NullCache()
|
||||
#The UI cache is expected to be iterable at the end of each workflow
|
||||
#so it must cache at least a full workflow. Use Hierarchical
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = NullCache()
|
||||
# only hold cached items while the descendants have not executed
|
||||
def init_dependency_aware_cache(self):
|
||||
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
@@ -136,7 +135,7 @@ class CacheSet:
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
@@ -154,10 +153,10 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if execution_list is None:
|
||||
if outputs is None:
|
||||
mark_missing()
|
||||
continue # This might be a lazily-evaluated input
|
||||
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
mark_missing()
|
||||
continue
|
||||
@@ -406,7 +405,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
@@ -436,7 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
@@ -448,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -551,15 +549,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
execution_list.cache_link(node_id, unique_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
execution_list.cache_update(unique_id, output_data)
|
||||
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
|
||||
2
main.py
2
main.py
@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.NONE
|
||||
cache_type = execution.CacheType.DEPENDENCY_AWARE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
|
||||
last_gc_collect = 0
|
||||
|
||||
2
nodes.py
2
nodes.py
@@ -2304,6 +2304,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_mahiro.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
@@ -2349,7 +2350,6 @@ async def init_builtin_api_nodes():
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_bytedance.py",
|
||||
"nodes_ltxv.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.67"
|
||||
version = "0.3.65"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@@ -50,8 +50,6 @@ messages_control.disable = [
|
||||
"too-many-branches",
|
||||
"too-many-locals",
|
||||
"too-many-arguments",
|
||||
"too-many-return-statements",
|
||||
"too-many-nested-blocks",
|
||||
"duplicate-code",
|
||||
"abstract-method",
|
||||
"superfluous-parens",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.28.8
|
||||
comfyui-workflow-templates==0.2.4
|
||||
comfyui-frontend-package==1.28.6
|
||||
comfyui-workflow-templates==0.1.95
|
||||
comfyui-embedded-docs==0.3.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
27
server.py
27
server.py
@@ -35,7 +35,6 @@ from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@@ -49,28 +48,6 @@ async def send_socket_catch_exception(function, message):
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
# Track deprecated paths that have been warned about to only warn once per file
|
||||
_deprecated_paths_warned = set()
|
||||
|
||||
@web.middleware
|
||||
async def deprecation_warning(request: web.Request, handler):
|
||||
"""Middleware to warn about deprecated frontend API paths"""
|
||||
path = request.path
|
||||
|
||||
if path.startswith("/scripts/ui") or path.startswith("/extensions/core/"):
|
||||
# Only warn once per unique file path
|
||||
if path not in _deprecated_paths_warned:
|
||||
_deprecated_paths_warned.add(path)
|
||||
logging.warning(
|
||||
f"[DEPRECATION WARNING] Detected import of deprecated legacy API: {path}. "
|
||||
f"This is likely caused by a custom node extension using outdated APIs. "
|
||||
f"Please update your extensions or contact the extension author for an updated version."
|
||||
)
|
||||
|
||||
response: web.Response = await handler(request)
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def compress_body(request: web.Request, handler):
|
||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||
@@ -174,7 +151,6 @@ class PromptServer():
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@@ -183,7 +159,7 @@ class PromptServer():
|
||||
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||
self.number = 0
|
||||
|
||||
middlewares = [cache_control, deprecation_warning]
|
||||
middlewares = [cache_control]
|
||||
if args.enable_compress_response_body:
|
||||
middlewares.append(compress_body)
|
||||
|
||||
@@ -821,7 +797,6 @@ class PromptServer():
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
@@ -152,12 +152,12 @@ class TestExecution:
|
||||
# Initialize server and client
|
||||
#
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
{ "extra_args" : [], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
|
||||
# (use_lru, lru_size)
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def server(self, args_pytest, request):
|
||||
def _server(self, args_pytest, request):
|
||||
# Start server
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
@@ -167,10 +167,12 @@ class TestExecution:
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
pargs += [ str(param) for param in request.param["extra_args"] ]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
print("Running server with args:", pargs) # noqa: T201
|
||||
p = subprocess.Popen(pargs)
|
||||
yield request.param
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -191,7 +193,7 @@ class TestExecution:
|
||||
return comfy_client
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, server):
|
||||
def shared_client(self, args_pytest, _server):
|
||||
client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||
yield client
|
||||
del client
|
||||
@@ -223,7 +225,7 @@ class TestExecution:
|
||||
assert result.did_run(mask)
|
||||
assert result.did_run(lazy_mix)
|
||||
|
||||
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@@ -235,12 +237,9 @@ class TestExecution:
|
||||
client.run(g)
|
||||
result2 = client.run(g)
|
||||
for node_id, node in g.nodes.items():
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
else:
|
||||
assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
|
||||
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@@ -252,12 +251,8 @@ class TestExecution:
|
||||
client.run(g)
|
||||
mask.inputs['value'] = 0.4
|
||||
result2 = client.run(g)
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
else:
|
||||
assert result2.did_run(input1), "Input1 should have been rerun"
|
||||
assert result2.did_run(input2), "Input2 should have been rerun"
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
|
||||
def test_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@@ -416,7 +411,7 @@ class TestExecution:
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
client.run(g)
|
||||
|
||||
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
# Creating the nodes in this specific order previously caused a bug
|
||||
save = g.node("SaveImage")
|
||||
@@ -432,10 +427,7 @@ class TestExecution:
|
||||
result3 = client.run(g)
|
||||
result4 = client.run(g)
|
||||
assert result1.did_run(is_changed), "is_changed should have been run"
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(is_changed), "is_changed should have been cached"
|
||||
else:
|
||||
assert result2.did_run(is_changed), "is_changed should have been re-run"
|
||||
assert not result2.did_run(is_changed), "is_changed should have been cached"
|
||||
assert result3.did_run(is_changed), "is_changed should have been re-run"
|
||||
assert result4.did_run(is_changed), "is_changed should not have been cached"
|
||||
|
||||
@@ -522,7 +514,7 @@ class TestExecution:
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||
@@ -538,11 +530,7 @@ class TestExecution:
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
if server["should_cache_results"]:
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
|
||||