Compare commits

..

25 Commits

Author SHA1 Message Date
Jedrzej Kosinski
386e854aab Merge branch 'master' into flipflop-stream 2025-10-28 15:08:27 -07:00
Jedrzej Kosinski
61133af772 Add '--flipflop-offload' startup argument 2025-10-13 21:10:44 -07:00
Jedrzej Kosinski
586a8de8da Merge branch 'master' into flipflop-stream 2025-10-13 21:04:37 -07:00
Jedrzej Kosinski
5329180fce Made flipflop consider partial_unload, partial_offload, and add flip+flop to mem counters 2025-10-03 16:21:01 -07:00
Jedrzej Kosinski
0fdd327c2f Merge branch 'master' into flipflop-stream 2025-10-03 14:32:56 -07:00
Jedrzej Kosinski
ee01002e63 Add flipflop support to (base) WAN, fix issue with applying loras to flipflop weights being done on CPU instead of GPU, left some timing functions as the lora application time could use some reduction 2025-10-02 22:02:50 -07:00
Jedrzej Kosinski
831c3cf05e Add a temporary workaround for odd amount of blocks not producing expected results 2025-10-02 20:29:11 -07:00
Jedrzej Kosinski
0d8e8abd90 Default ro smaller blocks getting flipflopped first 2025-10-02 18:00:21 -07:00
Jedrzej Kosinski
d5001ed90e Make flux support flipflop 2025-10-02 17:53:22 -07:00
Jedrzej Kosinski
8d7b22b720 Fixed FlipFlipModule.execute_blocks having hardcoded strings from Qwen 2025-10-02 17:49:43 -07:00
Jedrzej Kosinski
6d3ec9fcf3 Simplified flipflop setup by adding FlipFlopModule.execute_blocks helper 2025-10-02 16:46:37 -07:00
Jedrzej Kosinski
c4420b6a41 Change log string slightly 2025-10-02 15:34:35 -07:00
Jedrzej Kosinski
a282586995 Merge branch 'master' into flipflop-stream 2025-10-02 15:03:26 -07:00
Jedrzej Kosinski
0df61b5032 Fix improper index slicing for flipflop get blocks, add extra log message 2025-10-01 21:21:36 -07:00
Jedrzej Kosinski
7c896c5567 Initial automatic support for flipflop within ModelPatcher - only Qwen Image diffusion_model uses FlipFlopModule currently 2025-10-01 20:13:50 -07:00
Jedrzej Kosinski
ec156e72eb Merge branch 'master' into flipflop-stream 2025-09-30 23:08:37 -07:00
Jedrzej Kosinski
01f4512bf8 In-progress commit on making flipflop async weight streaming native, made loaded partially/loaded completely log messages have labels because having to memorize their meaning for dev work is annoying 2025-09-30 23:08:08 -07:00
Jedrzej Kosinski
d0bd221495 Merge branch 'master' into flipflop-stream 2025-09-29 22:49:38 -07:00
Jedrzej Kosinski
8a8162e8da Fix percentage logic, begin adding elements to ModelPatcher to track flip flop compatibility 2025-09-29 22:49:12 -07:00
Jedrzej Kosinski
ff789c8beb Merge branch 'master' into flipflop-stream 2025-09-29 16:09:51 -07:00
Jedrzej Kosinski
0e966dcf85 Merge branch 'master' into flipflop-stream 2025-09-27 21:13:26 -07:00
Jedrzej Kosinski
6b240b0bce Refactored old flip flop into a new implementation that allows for controlling the percentage of blocks getting flip flopped, converted nodes to v3 schema 2025-09-25 22:41:41 -07:00
Jedrzej Kosinski
f9fbf902d5 Added missing Qwen block params, further subdivided blocks function 2025-09-25 17:49:39 -07:00
Jedrzej Kosinski
f083720eb4 Refactored FlipFlopTransformer.__call__ to fully separate out actions between flip and flop 2025-09-25 16:16:51 -07:00
Jedrzej Kosinski
84e73f2aa5 Brought over flip flop prototype from contentis' fork, limiting it to only Qwen to ease the process of adapting it to be a native feature 2025-09-25 16:15:46 -07:00
62 changed files with 3921 additions and 2086 deletions

View File

@@ -8,15 +8,13 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
`--disable-all-custom-nodes` command line argument.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@@ -112,11 +112,10 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Releases a new stable version (e.g., v0.7.0)
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**

View File

@@ -10,8 +10,7 @@ import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Dict, TypedDict, Optional
from aiohttp import web
from typing import TypedDict, Optional
from importlib.metadata import version
import requests
@@ -258,54 +257,7 @@ comfyui-frontend-package is not installed.
sys.exit(-1)
@classmethod
def template_asset_map(cls) -> Optional[Dict[str, str]]:
"""Return a mapping of template asset names to their absolute paths."""
try:
from comfyui_workflow_templates import (
get_asset_path,
iter_templates,
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
try:
template_entries = list(iter_templates())
except Exception as exc:
logging.error(f"Failed to enumerate workflow templates: {exc}")
return None
asset_map: Dict[str, str] = {}
try:
for entry in template_entries:
for asset in entry.assets:
asset_map[asset.filename] = get_asset_path(
entry.template_id, asset.filename
)
except Exception as exc:
logging.error(f"Failed to resolve template asset paths: {exc}")
return None
if not asset_map:
logging.error("No workflow template assets found. Did the packages install correctly?")
return None
return asset_map
@classmethod
def legacy_templates_path(cls) -> Optional[str]:
"""Return the legacy templates directory shipped inside the meta package."""
def templates_path(cls) -> str:
try:
import comfyui_workflow_templates
@@ -324,7 +276,6 @@ comfyui-workflow-templates is not installed.
********** ERROR ***********
""".strip()
)
return None
@classmethod
def embedded_docs_path(cls) -> str:
@@ -441,17 +392,3 @@ comfyui-workflow-templates is not installed.
logging.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):
assets = cls.template_asset_map()
if not assets:
return None
async def serve_template(request: web.Request) -> web.StreamResponse:
rel_path = request.match_info.get("path", "")
target = assets.get(rel_path)
if target is None:
raise web.HTTPNotFound()
return web.FileResponse(target)
return serve_template

View File

@@ -105,7 +105,6 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -133,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--flipflop-offload", action="store_true", help="Use async flipflop weight offloading for supported DiT models.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@@ -146,9 +147,7 @@ class PerformanceFeature(enum.Enum):
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@@ -310,13 +310,11 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -352,13 +350,12 @@ class ControlLoraOps:
def forward(self, input):
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import torch
import copy
import comfy.model_management
class FlipFlopModule(torch.nn.Module):
def __init__(self, block_types: tuple[str, ...], enable_flipflop: bool = True):
super().__init__()
self.block_types = block_types
self.enable_flipflop = enable_flipflop
self.flipflop: dict[str, FlipFlopHolder] = {}
self.block_info: dict[str, tuple[int, int]] = {}
self.flipflop_prefixes: list[str] = []
def setup_flipflop_holders(self, block_info: dict[str, tuple[int, int]], flipflop_prefixes: list[str], load_device: torch.device, offload_device: torch.device):
for block_type, (flipflop_blocks, total_blocks) in block_info.items():
if block_type in self.flipflop:
continue
self.flipflop[block_type] = FlipFlopHolder(getattr(self, block_type)[total_blocks-flipflop_blocks:], flipflop_blocks, total_blocks, load_device, offload_device)
self.block_info[block_type] = (flipflop_blocks, total_blocks)
self.flipflop_prefixes = flipflop_prefixes.copy()
def init_flipflop_block_copies(self, device: torch.device) -> int:
memory_freed = 0
for holder in self.flipflop.values():
memory_freed += holder.init_flipflop_block_copies(device)
return memory_freed
def clean_flipflop_holders(self):
memory_freed = 0
for block_type in list(self.flipflop.keys()):
memory_freed += self.flipflop[block_type].clean_flipflop_blocks()
del self.flipflop[block_type]
self.block_info = {}
self.flipflop_prefixes = []
return memory_freed
def get_all_blocks(self, block_type: str) -> list[torch.nn.Module]:
return getattr(self, block_type)
def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
if block_type not in self.block_types:
raise ValueError(f"Block type {block_type} not found in {self.block_types}")
if block_type in self.flipflop:
return getattr(self, block_type)[:self.flipflop[block_type].i_offset]
return getattr(self, block_type)
def get_all_block_module_sizes(self, reverse_sort_by_size: bool = False) -> list[tuple[str, int]]:
'''
Returns a list of (block_type, size) sorted by size.
If reverse_sort_by_size is True, the list is sorted by size in reverse order.
'''
sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
sizes.sort(key=lambda x: x[1], reverse=reverse_sort_by_size)
return sizes
def get_block_module_size(self, block_type: str) -> int:
return comfy.model_management.module_size(getattr(self, block_type)[0])
def execute_blocks(self, block_type: str, func, out: torch.Tensor | tuple[torch.Tensor,...], *args, **kwargs):
# execute blocks, supporting both single and double (or higher) block types
if isinstance(out, torch.Tensor):
out = (out,)
for i, block in enumerate(self.get_blocks(block_type)):
out = func(i, block, *out, *args, **kwargs)
if isinstance(out, torch.Tensor):
out = (out,)
if block_type in self.flipflop:
holder = self.flipflop[block_type]
with holder.context() as ctx:
for i, block in enumerate(holder.blocks):
out = ctx(func, i, block, *out, *args, **kwargs)
if isinstance(out, torch.Tensor):
out = (out,)
if len(out) == 1:
out = out[0]
return out
class FlipFlopContext:
def __init__(self, holder: FlipFlopHolder):
# NOTE: there is a bug when there are an odd number of blocks to flipflop.
# Worked around right now by always making sure it will be even, but need to resolve.
self.holder = holder
self.reset()
def reset(self):
self.num_blocks = len(self.holder.blocks)
self.first_flip = True
self.first_flop = True
self.last_flip = False
self.last_flop = False
def __enter__(self):
self.reset()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.holder.compute_stream.record_event(self.holder.cpy_end_event)
def do_flip(self, func, i: int, _, *args, **kwargs):
# flip
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i+self.holder.i_offset, self.holder.flip, *args, **kwargs)
self.holder.event_flip.record(self.holder.compute_stream)
# while flip executes, queue flop to copy to its next block
next_flop_i = i + 1
if next_flop_i >= self.num_blocks:
next_flop_i = next_flop_i - self.num_blocks
self.last_flip = True
if not self.first_flip:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
if self.last_flip:
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
self.first_flip = False
return out
def do_flop(self, func, i: int, _, *args, **kwargs):
# flop
if not self.first_flop:
self.holder.compute_stream.wait_event(self.holder.cpy_end_event)
with torch.cuda.stream(self.holder.compute_stream):
out = func(i+self.holder.i_offset, self.holder.flop, *args, **kwargs)
self.holder.event_flop.record(self.holder.compute_stream)
# while flop executes, queue flip to copy to its next block
next_flip_i = i + 1
if next_flip_i >= self.num_blocks:
next_flip_i = next_flip_i - self.num_blocks
self.last_flop = True
self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
if self.last_flop:
self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
self.first_flop = False
return out
@torch.no_grad()
def __call__(self, func, i: int, block: torch.nn.Module, *args, **kwargs):
# flips are even indexes, flops are odd indexes
if i % 2 == 0:
return self.do_flip(func, i, block, *args, **kwargs)
else:
return self.do_flop(func, i, block, *args, **kwargs)
class FlipFlopHolder:
def __init__(self, blocks: list[torch.nn.Module], flip_amount: int, total_amount: int, load_device: torch.device, offload_device: torch.device):
self.load_device = load_device
self.offload_device = offload_device
self.blocks = blocks
self.flip_amount = flip_amount
self.total_amount = total_amount
# NOTE: used to make sure block indexes passed into block functions match expected patch indexes
self.i_offset = total_amount - flip_amount
self.block_module_size = 0
if len(self.blocks) > 0:
self.block_module_size = comfy.model_management.module_size(self.blocks[0])
self.flip: torch.nn.Module = None
self.flop: torch.nn.Module = None
self.compute_stream = torch.cuda.default_stream(self.load_device)
self.cpy_stream = torch.cuda.Stream(self.load_device)
self.event_flip = torch.cuda.Event(enable_timing=False)
self.event_flop = torch.cuda.Event(enable_timing=False)
self.cpy_end_event = torch.cuda.Event(enable_timing=False)
# INIT - is this actually needed?
self.compute_stream.record_event(self.cpy_end_event)
def _copy_state_dict(self, dst, src, cpy_start_event: torch.cuda.Event=None, cpy_end_event: torch.cuda.Event=None):
if cpy_start_event:
self.cpy_stream.wait_event(cpy_start_event)
with torch.cuda.stream(self.cpy_stream):
for k, v in src.items():
dst[k].copy_(v, non_blocking=True)
if cpy_end_event:
cpy_end_event.record(self.cpy_stream)
def context(self):
return FlipFlopContext(self)
def init_flipflop_block_copies(self, load_device: torch.device) -> int:
self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
return comfy.model_management.module_size(self.flip) + comfy.model_management.module_size(self.flop)
def clean_flipflop_blocks(self) -> int:
memory_freed = 0
memory_freed += comfy.model_management.module_size(self.flip)
memory_freed += comfy.model_management.module_size(self.flop)
del self.flip
del self.flop
self.flip = None
self.flop = None
return memory_freed

View File

@@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)

View File

@@ -7,7 +7,15 @@ import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q, k = apply_rope(q, k, pe)
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@@ -7,6 +7,7 @@ from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flipflop_transformer import FlipFlopModule
from .layers import (
DoubleStreamBlock,
@@ -35,13 +36,13 @@ class FluxParams:
guidance_embed: bool
class Flux(nn.Module):
class Flux(FlipFlopModule):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
super().__init__(("double_blocks", "single_blocks"))
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
@@ -89,6 +90,72 @@ class Flux(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
return img, txt
def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
return img
def forward_orig(
self,
img: Tensor,
@@ -136,74 +203,16 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
# execute double blocks
img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
# execute single blocks
img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
img = img[:, txt.shape[1] :, ...]

View File

@@ -3,11 +3,12 @@ from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -237,6 +238,20 @@ class FeedForward(nn.Module):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@@ -266,8 +281,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -291,17 +306,12 @@ class BasicTransformerBlock(nn.Module):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
@@ -317,35 +327,41 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32
device = indices_grid.device
dtype = torch.float32 #self.dtype
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
start = 1
end = theta
device = fractional_positions.device
# Pad if dim is not divisible by 6
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
return freqs_cis
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
@@ -485,7 +501,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = torch.addcmul(x, x, scale).add_(shift)
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(

View File

@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
@@ -531,22 +531,10 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@@ -5,12 +5,13 @@ import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.flipflop_transformer import FlipFlopModule
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
import comfy.ops
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -135,34 +136,33 @@ class Attention(nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -285,7 +285,7 @@ class LastLayer(nn.Module):
return x
class QwenImageTransformer2DModel(nn.Module):
class QwenImageTransformer2DModel(FlipFlopModule):
def __init__(
self,
patch_size: int = 2,
@@ -302,9 +302,9 @@ class QwenImageTransformer2DModel(nn.Module):
final_layer=True,
dtype=None,
device=None,
operations=None,
operations: comfy.ops.disable_weight_init=None,
):
super().__init__()
super().__init__(block_types=("transformer_blocks",))
self.dtype = dtype
self.patch_size = patch_size
self.in_channels = in_channels
@@ -368,6 +368,40 @@ class QwenImageTransformer2DModel(nn.Module):
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
return hidden_states, encoder_hidden_states
def _forward(
self,
x,
@@ -415,7 +449,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
@@ -435,37 +469,8 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
out = (hidden_states, encoder_hidden_states)
hidden_states, encoder_hidden_states = self.execute_blocks("transformer_blocks", self.indiv_block_fwd, out, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flipflop_transformer import FlipFlopModule
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
@@ -232,7 +233,6 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
@@ -385,7 +385,7 @@ class MLPProj(torch.nn.Module):
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
class WanModel(FlipFlopModule):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
@@ -413,6 +413,7 @@ class WanModel(torch.nn.Module):
device=None,
dtype=None,
operations=None,
enable_flipflop=True,
):
r"""
Initialize the diffusion model backbone.
@@ -450,7 +451,7 @@ class WanModel(torch.nn.Module):
Epsilon value for normalization layers
"""
super().__init__()
super().__init__(block_types=("blocks",), enable_flipflop=enable_flipflop)
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
@@ -507,6 +508,18 @@ class WanModel(torch.nn.Module):
else:
self.ref_conv = None
def indiv_block_fwd(self, i, block, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
return x
def forward_orig(
self,
x,
@@ -568,16 +581,8 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# execute blocks
x = self.execute_blocks("blocks", self.indiv_block_fwd, x, e0, freqs, context, context_img_len, blocks_replace, transformer_options)
# head
x = self.head(x, e)
@@ -589,7 +594,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -602,22 +607,10 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -643,7 +636,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@@ -701,7 +694,7 @@ class VaceWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
# Vace
@@ -821,7 +814,7 @@ class CameraWanModel(WanModel):
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -1224,7 +1217,7 @@ class WanModel_S2V(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
@@ -1524,7 +1517,7 @@ class HumoWanModel(WanModel):
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)

View File

@@ -426,7 +426,7 @@ class AnimateWanModel(WanModel):
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations, enable_flipflop=False)
self.pose_patch_embedding = operations.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype

View File

@@ -1006,6 +1006,8 @@ def force_channels_last():
#TODO
return False
def flipflop_enabled():
return args.flipflop_offload
STREAMS = {}
NUM_STREAMS = 1
@@ -1013,16 +1015,6 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@@ -1031,17 +1023,21 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return ss[stream_counter]
return s
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@@ -1050,14 +1046,18 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None or current_stream(device) is None:
if stream is None:
return
current_stream(device).wait_stream(stream)
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1082,73 +1082,6 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
#NOTE: Cuda does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU async. So dont trust the CUDA API and guard here
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention

View File

@@ -25,7 +25,7 @@ import logging
import math
import uuid
from typing import Callable, Optional
import time # TODO remove
import torch
import comfy.float
@@ -238,7 +238,6 @@ class ModelPatcher:
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -276,9 +275,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -298,7 +294,6 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -455,19 +450,6 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -609,7 +591,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, device_final=None):
if key not in self.patches:
return
@@ -629,30 +611,103 @@ class ModelPatcher:
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if device_final is not None:
out_weight = out_weight.to(device_final)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
if device_final is not None:
out_weight = out_weight.to(device_final)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def supports_flipflop(self):
# flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
if not comfy.model_management.flipflop_enabled():
return False
if not hasattr(self.model, "diffusion_model"):
return False
if not getattr(self.model.diffusion_model, "enable_flipflop", False):
return False
if not comfy.model_management.is_nvidia():
return False
if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
return False
return True
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def setup_flipflop(self, flipflop_blocks_per_type: dict[str, tuple[int, int]], flipflop_prefixes: list[str]):
if not self.supports_flipflop():
return
logging.info(f"setting up flipflop with {flipflop_blocks_per_type}")
self.model.diffusion_model.setup_flipflop_holders(flipflop_blocks_per_type, flipflop_prefixes, self.load_device, self.offload_device)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def init_flipflop_block_copies(self) -> int:
if not self.supports_flipflop():
return 0
return self.model.diffusion_model.init_flipflop_block_copies(self.load_device)
def _load_list(self):
def clean_flipflop(self) -> int:
if not self.supports_flipflop():
return 0
return self.model.diffusion_model.clean_flipflop_holders()
def _get_existing_flipflop_prefixes(self):
if self.supports_flipflop():
return self.model.diffusion_model.flipflop_prefixes
return []
def _calc_flipflop_prefixes(self, lowvram_model_memory=0, prepare_flipflop=False):
flipflop_prefixes = []
flipflop_blocks_per_type: dict[str, tuple[int, int]] = {}
if lowvram_model_memory > 0 and self.supports_flipflop():
block_buffer = 3
valid_block_types = []
# for each block type, check if have enough room to flipflop
for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=True):
block_size: int = block_info[1]
if block_size * block_buffer < lowvram_model_memory:
valid_block_types.append(block_info)
# if have candidates for flipping, see how many of each type we have can flipflop
if len(valid_block_types) > 0:
leftover_memory = lowvram_model_memory
for block_info in valid_block_types:
block_type: str = block_info[0]
block_size: int = block_info[1]
total_blocks = len(self.model.diffusion_model.get_all_blocks(block_type))
n_fit_in_memory = int(leftover_memory // block_size)
# if all (or more) of this block type would fit in memory, no need to flipflop with it
if n_fit_in_memory >= total_blocks:
leftover_memory -= total_blocks * block_size
continue
# if the amount of this block that would fit in memory is less than buffer, skip this block type
if n_fit_in_memory < block_buffer:
continue
# 2 blocks worth of VRAM may be needed for flipflop, so make sure to account for them.
flipflop_blocks = min((total_blocks - n_fit_in_memory) + 2, total_blocks)
# for now, work around odd number issue by making it even
if flipflop_blocks % 2 != 0:
if flipflop_blocks == total_blocks:
flipflop_blocks -= 1
else:
flipflop_blocks += 1
flipflop_blocks_per_type[block_type] = (flipflop_blocks, total_blocks)
leftover_memory -= (total_blocks - flipflop_blocks + 2) * block_size
# if there are blocks to flipflop, need to mark their keys
for block_type, (flipflop_blocks, total_blocks) in flipflop_blocks_per_type.items():
# blocks to flipflop are at the end
for i in range(total_blocks-flipflop_blocks, total_blocks):
flipflop_prefixes.append(f"diffusion_model.{block_type}.{i}")
if prepare_flipflop and len(flipflop_blocks_per_type) > 0:
self.setup_flipflop(flipflop_blocks_per_type, flipflop_prefixes)
return flipflop_prefixes
def _load_list(self, lowvram_model_memory=0, prepare_flipflop=False, get_existing_flipflop=False):
loading = []
if get_existing_flipflop:
flipflop_prefixes = self._get_existing_flipflop_prefixes()
else:
flipflop_prefixes = self._calc_flipflop_prefixes(lowvram_model_memory, prepare_flipflop)
for n, m in self.model.named_modules():
params = []
skip = False
@@ -663,7 +718,12 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
flipflop = False
for prefix in flipflop_prefixes:
if n.startswith(prefix):
flipflop = True
break
loading.append((comfy.model_management.module_size(m), n, m, params, flipflop))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -673,15 +733,18 @@ class ModelPatcher:
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
flipflop_counter = 0
flipflop_mem_counter = 0
loading = self._load_list(lowvram_model_memory, prepare_flipflop=True)
load_completely = []
offloaded = []
load_flipflop = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
flipflop: bool = x[4]
module_mem = x[0]
lowvram_weight = False
@@ -689,7 +752,7 @@ class ModelPatcher:
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if not full_load and hasattr(m, "comfy_cast_weights") and not flipflop:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
@@ -719,12 +782,15 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
if flipflop:
flipflop_counter += 1
flipflop_mem_counter += module_mem
load_flipflop.append((module_mem, n, m, params))
elif full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
@@ -740,6 +806,7 @@ class ModelPatcher:
mem_counter += move_weight_functions(m, device_to)
# handle load completely
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
@@ -750,9 +817,7 @@ class ModelPatcher:
continue
for param in params:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -760,17 +825,36 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
# handle flipflop
if len(load_flipflop) > 0:
start_time = time.perf_counter()
load_flipflop.sort(reverse=True)
for x in load_flipflop:
n = x[1]
m = x[2]
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to, device_final=self.offload_device)
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
logging.debug("lowvram: loaded module for flipflop {} {}".format(n, m))
end_time = time.perf_counter()
logging.info(f"flipflop load time: {end_time - start_time:.2f} seconds")
start_time = time.perf_counter()
mem_counter += self.init_flipflop_block_copies()
end_time = time.perf_counter()
logging.info(f"flipflop block init time: {end_time - start_time:.2f} seconds")
if lowvram_counter > 0 or flipflop_counter > 0:
if flipflop_counter > 0:
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {flipflop_mem_counter / (1024 * 1024):.2f} MB flipflop, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
else:
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -807,7 +891,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
self.clean_flipflop()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -847,8 +931,9 @@ class ModelPatcher:
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
memory_freed += self.clean_flipflop()
patch_counter = 0
unload_list = self._load_list()
unload_list = self._load_list(get_existing_flipflop=True)
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
@@ -857,7 +942,10 @@ class ModelPatcher:
n = unload[1]
m = unload[2]
params = unload[3]
flipflop: bool = unload[4]
if flipflop:
continue
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True
@@ -903,9 +991,6 @@ class ModelPatcher:
memory_freed += module_mem
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
@@ -1308,6 +1393,5 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -70,11 +70,8 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@torch.compiler.disable()
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
@@ -83,58 +80,32 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if bias_has_function:
if has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
weight = weight.to(dtype=dtype)
if weight_has_function:
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
with wf_context:
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -147,10 +118,8 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
@@ -164,10 +133,8 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
@@ -181,10 +148,8 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
@@ -207,10 +172,8 @@ class disable_weight_init:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
@@ -224,10 +187,8 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
@@ -242,14 +203,11 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
@@ -265,15 +223,11 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
@@ -292,12 +246,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -316,12 +268,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -339,11 +289,8 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
run_every_op()
@@ -405,10 +352,16 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
scale_weight = self.scale_weight
scale_input = self.scale_input
@@ -419,21 +372,19 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream)
return o
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None
@@ -453,10 +404,8 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -484,14 +433,12 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): #TODO: optimize
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@@ -534,12 +481,12 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": "TensorCoreFP8Layout",
"layout_type": TensorCoreFP8Layout,
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
@@ -630,10 +577,8 @@ class MixedPrecisionOps(disable_weight_init):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
weight, bias = cast_bias_weight(self, input)
return self._forward(input, weight, bias)
def forward(self, input, *args, **kwargs):
run_every_op()

View File

@@ -123,15 +123,15 @@ class QuantizedTensor(torch.Tensor):
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._qdata = qdata.contiguous()
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
layout_name = self._layout_type.__name__
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@@ -179,15 +179,15 @@ class QuantizedTensor(torch.Tensor):
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
return self._layout_type.dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -357,10 +357,9 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
lp_amax = torch.finfo(dtype).max
tensor_scaled = tensor.float() / scale
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
@@ -379,12 +378,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
return qtensor._qdata, qtensor._layout_params['scale']
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
@@ -411,17 +405,13 @@ def fp8_linear(func, args, kwargs):
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
plain_input.reshape(-1, input_shape[2]),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
@@ -431,7 +421,7 @@ def fp8_linear(func, args, kwargs):
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
else:
return output
@@ -445,68 +435,3 @@ def fp8_linear(func, args, kwargs):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -143,9 +143,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -296,7 +293,6 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -599,16 +595,6 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
@@ -1344,7 +1330,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
if model_config.layer_quant_config is not None:
if hasattr(model_config, "layer_quant_config"):
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)

View File

@@ -0,0 +1,261 @@
from __future__ import annotations
import aiohttp
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
from .util import tensor_to_bytesio, bytesio_to_image_tensor
from io import BytesIO
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask

View File

@@ -0,0 +1,17 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@@ -0,0 +1,57 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@@ -0,0 +1,981 @@
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
self._session: Optional[aiohttp.ClientSession] = session
self._owns_session = session is None # Track if we have to close it
@staticmethod
def _generate_operation_id(path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
@staticmethod
def _create_json_payload_args(
data: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: dict[str, Any] | None,
files: dict[str, Any] | None,
headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
if isinstance(data, aiohttp.FormData):
form = data # If the parser already returned a FormData, pass it through
else:
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue # aiohttp fails to serialize "None" values
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
if isinstance(file_obj, tuple):
filename, file_value, content_type = self._unpack_tuple(file_obj)
else:
file_value = file_obj
filename = getattr(file_obj, "name", field_name)
content_type = "application/octet-stream"
form.add_field(
name=field_name,
value=file_value,
filename=filename,
content_type=content_type,
)
return {"data": form, "headers": headers or {}}
@staticmethod
def _create_urlencoded_form_data_args(
data: dict[str, Any],
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results # cannot reach the internet early exit
# Now check API health endpoint
parsed = urlparse(target_url)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
try:
async with session.get(health_url, ssl=self.verify_ssl) as resp:
results["api_accessible"] = resp.status < 500
except ClientError:
pass # leave as False
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
async def request(
self,
method: str,
path: str,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Build full URL and merge headers
relative_path = path.lstrip("/")
url = urljoin(self.base_url, relative_path)
self._check_auth(self.auth_token, self.comfy_api_key)
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
if files:
request_headers.pop("Content-Type", None)
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug("[DEBUG] Files: %s", files)
logging.debug("[DEBUG] Params: %s", params)
logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]",
)
session = await self._get_session()
try:
async with session.request(
method,
url,
params=params,
ssl=self.verify_ssl,
**payload_args,
) as resp:
if resp.status >= 400:
try:
error_data = await resp.json()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
error_data = await resp.text()
return await self._handle_http_error(
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
operation_id,
method,
url,
params,
data,
files,
headers,
content_type,
multipart_parser,
retry_count=retry_count,
response_content=error_data,
)
# Success parse JSON (safely) and log
try:
payload = await resp.json()
response_content_to_log = payload
except (aiohttp.ContentTypeError, json.JSONDecodeError):
payload = {}
response_content_to_log = await resp.text()
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
# Treat as *connection* problem optionally retry, else escalate
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
self.max_retries, str(e))
await asyncio.sleep(delay)
return await self.request(
method,
path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# One final connectivity check for diagnostics
connectivity = await self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
@staticmethod
def _check_auth(auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
async def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> aiohttp.ClientResponse:
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
skip_auto_headers.add("Content-Type")
# Extract file bytes
if isinstance(file, io.BytesIO):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("File must be BytesIO or str path")
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data {len(data)} bytes]",
)
delay = retry_delay
for attempt in range(max_retries + 1):
try:
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.put(
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
) as resp:
resp.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return resp
except (ClientError, asyncio.TimeoutError) as e:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
if attempt < max_retries:
logging.warning(
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
)
await asyncio.sleep(delay)
delay *= retry_backoff_factor
else:
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
async def _handle_http_error(
self,
exc: ClientResponseError,
operation_id: str,
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_friendly = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_friendly = "Rate Limit Exceeded: Please try again later."
else:
if isinstance(response_content, dict):
if "error" in response_content and "message" in response_content["error"]:
user_friendly = f"API Error: {response_content['error']['message']}"
if "type" in response_content["error"]:
user_friendly += f" (Type: {response_content['error']['type']})"
else: # Handle cases where error is just a JSON dict with unknown format
user_friendly = f"API Error: {json.dumps(response_content)}"
else:
if len(response_content) < 200: # Arbitrary limit for display
user_friendly = f"API Error (raw): {response_content}"
else:
user_friendly = f"API Error (raw, status {response_content})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=req_meta[0],
request_url=req_meta[1],
response_status_code=exc.status,
response_headers=dict(req_meta[5]) if req_meta[5] else None,
response_content=response_content,
error_message=f"HTTP Error {exc.status}",
)
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content:
logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
"HTTP error %s. Retrying in %.2fs (%s/%s)",
status_code,
delay,
retry_count + 1,
self.max_retries,
)
await asyncio.sleep(delay)
return await self.request(
req_meta[0], # method
req_meta[1].replace(self.base_url, ""), # path
params=req_meta[2],
data=req_meta[3],
files=req_meta[4],
headers=req_meta[5],
content_type=req_meta[6],
multipart_parser=req_meta[7],
retry_count=retry_count + 1,
)
raise Exception(user_friendly) from exc
@staticmethod
def _unpack_tuple(t):
"""Helper to normalise (filename, file, content_type) tuples."""
if len(t) == 3:
return t
elif len(t) == 2:
return t[0], t[1], "application/octet-stream"
else:
raise ValueError("files tuple must be (filename, file[, content_type])")
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
self._owns_session = True
return self._session
async def close(self) -> None:
if self._owns_session and self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ApiClient":
"""Allow usage as asynccontextmanager ensures clean teardown"""
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""Represents a single synchronous API operation."""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> None:
self.endpoint = endpoint
self.request = request
self.files = files
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
request_dict = self.request.model_dump(exclude_none=True)
for k, v in list(request_dict.items()):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
response_json = await client.request(
self.endpoint.method.value,
self.endpoint.path,
params=self.endpoint.query_params,
data=request_dict,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser,
)
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response
finally:
if owns_client:
await client.close()
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""Represents an asynchronous API operation that requires polling for completion."""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
) -> None:
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
return await self._poll_until_complete(client)
finally:
if owns_client:
await client.close()
def _display_text_on_node(self, text: str):
if not self.node_id:
return
if self.extracted_price is not None:
text = f"Price: ${self.extracted_price}\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
if not self.node_id:
return
if self.estimated_duration is not None:
remaining = max(0, int(self.estimated_duration) - time_completed)
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
else:
message = f"Task in progress: {time_completed}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
if status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error("Error extracting status: %s", e)
return TaskStatus.PENDING
async def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
if poll_count == 1:
logging.debug(
"[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
)
logging.debug(
"[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
)
# Query task status
resp = await client.request(
self.poll_endpoint.method.value,
self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
consecutive_errors = 0 # reset on success
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error("[DEBUG] %s", message)
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
for i in range(int(self.poll_interval)):
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
await asyncio.sleep(1)
except (LocalNetworkError, ApiServerError, NetworkError) as e:
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning(
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
"The operation may still be running on the server but is taking longer than expected."
)

View File

@@ -1,120 +0,0 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
import os
import datetime
import hashlib
import json
import logging
import os
import re
import hashlib
from typing import Any
import folder_paths

View File

@@ -5,6 +5,10 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_aspect_ratio,
)
from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
@@ -19,10 +23,8 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
resize_mask_to_image,
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_string,
)
@@ -41,6 +43,11 @@ class FluxProUltraImageNode(IO.ComfyNode):
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -105,7 +112,16 @@ class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
def validate_inputs(cls, aspect_ratio: str):
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
@classmethod
@@ -129,7 +145,13 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt=prompt,
prompt_upsampling=prompt_upsampling,
seed=seed,
aspect_ratio=aspect_ratio,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
raw=raw,
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@@ -158,6 +180,11 @@ class FluxKontextProImageNode(IO.ComfyNode):
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -234,7 +261,13 @@ class FluxKontextProImageNode(IO.ComfyNode):
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
aspect_ratio = validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
if input_image is None:
validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op(

View File

@@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_string,
)
@@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio(image, (1, 3), (3, 1))
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
payload = Image2ImageTaskCreationRequest(
model=model,
@@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
reference_images_urls = []
if n_input_images:
for i in image:
validate_image_aspect_ratio(i, (1, 3), (3, 1))
validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi(
cls,
image,
@@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
prompt = (
@@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
for i in (first_frame, last_frame):
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
download_urls = await upload_images_to_comfyapi(
cls,
@@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
for image in images:
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
prompt = (

View File

@@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import ComfyExtension, IO
from PIL import Image
import numpy as np
import torch
@@ -11,14 +11,20 @@ from comfy_api_nodes.apis import (
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
resize_mask_to_image,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
"512 x 1536":"RESOLUTION_512_1536",
@@ -214,7 +220,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@@ -227,6 +233,19 @@ async def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(IO.ComfyNode):
@classmethod
@@ -315,30 +334,44 @@ class IdeogramV1(IO.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
max_retries=1,
auth_kwargs=auth,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -467,11 +500,18 @@ class IdeogramV2(IO.ComfyNode):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
@@ -479,20 +519,28 @@ class IdeogramV2(IO.ComfyNode):
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
max_retries=1,
auth_kwargs=auth,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -608,6 +656,10 @@ class IdeogramV3(IO.ComfyNode):
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
@@ -642,6 +694,9 @@ class IdeogramV3(IO.ComfyNode):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
@@ -694,20 +749,27 @@ class IdeogramV3(IO.ComfyNode):
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
response_model=IdeogramGenerateResponse,
data=edit_request,
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files=files,
content_type="multipart/form-data",
max_retries=1,
auth_kwargs=auth,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
@@ -738,22 +800,32 @@ class IdeogramV3(IO.ComfyNode):
if files:
gen_request.style_type = "AUTO"
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=gen_request,
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None,
content_type="multipart/form-data",
max_retries=1,
auth_kwargs=auth,
)
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -766,6 +838,5 @@ class IdeogramExtension(ComfyExtension):
IdeogramV3,
]
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

View File

@@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
"""
validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
def get_video_from_response(response) -> KlingVideoResult:

View File

@@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
IO.Combo.Input(
"resolution",
options=[
@@ -85,10 +85,6 @@ class TextToVideoNode(IO.ComfyNode):
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
@@ -122,7 +118,7 @@ class ImageToVideoNode(IO.ComfyNode):
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
IO.Combo.Input(
"resolution",
options=[
@@ -162,10 +158,6 @@ class ImageToVideoNode(IO.ComfyNode):
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(

View File

@@ -1,51 +1,69 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaAspectRatio,
LumaCharacterRef,
LumaConceptChain,
LumaGeneration,
LumaGenerationRequest,
LumaImageGenerationRequest,
LumaImageIdentity,
LumaImageModel,
LumaImageReference,
LumaIO,
LumaKeyframes,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef,
LumaModifyImageRef,
LumaImageIdentity,
LumaReference,
LumaReferenceChain,
LumaVideoModel,
LumaVideoModelOutputDuration,
LumaVideoOutputResolution,
LumaImageReference,
LumaKeyframes,
LumaConceptChain,
LumaIO,
get_luma_concepts,
)
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
process_image_response,
)
from server import PromptServer
from comfy_api_nodes.util import validate_string
import aiohttp
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(IO.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description="Holds an image and weight for use with Luma Generate Image node.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input(
"image",
@@ -65,10 +83,17 @@ class LumaReferenceNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> IO.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
@@ -78,13 +103,17 @@ class LumaReferenceNode(IO.ComfyNode):
class LumaConceptsNode(IO.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"concept1",
@@ -109,6 +138,11 @@ class LumaConceptsNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
@@ -127,13 +161,17 @@ class LumaConceptsNode(IO.ComfyNode):
class LumaImageGenerationNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description="Generates images synchronously based on prompt and aspect ratio.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
@@ -199,30 +237,45 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str,
seed,
style_image_weight: float,
image_luma_ref: Optional[LumaReferenceChain] = None,
style_image: Optional[torch.Tensor] = None,
character_image: Optional[torch.Tensor] = None,
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=aspect_ratio,
@@ -230,21 +283,41 @@ class LumaImageGenerationNode(IO.ComfyNode):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=auth_kwargs,
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
@classmethod
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
async def _convert_luma_refs(
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
ref_count += 1
if ref_count >= max_refs:
@@ -252,19 +325,27 @@ class LumaImageGenerationNode(IO.ComfyNode):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
return await cls._convert_luma_refs(chain, max_refs=1)
async def _convert_style_image(
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(IO.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description="Modifies images synchronously based on prompt and aspect ratio.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input(
"image",
@@ -314,37 +395,68 @@ class LumaImageModifyNode(IO.ComfyNode):
image_weight: float,
seed,
) -> IO.NodeOutput:
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
)
image_url = download_urls[0]
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
# next, make Luma call with download url provided
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt,
model=model,
modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=auth_kwargs,
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
class LumaTextToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description="Generates videos synchronously based on prompt and output_size.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
@@ -386,7 +498,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
),
)
],
outputs=[IO.Video.Output()],
hidden=[
@@ -407,17 +519,24 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str,
loop: bool,
seed,
luma_concepts: Optional[LumaConceptChain] = None,
luma_concepts: LumaConceptChain = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt,
model=model,
resolution=resolution,
@@ -426,25 +545,47 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description="Generates videos synchronously based on prompt, input images, and output_size.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
@@ -496,7 +637,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
),
)
],
outputs=[IO.Video.Output()],
hidden=[
@@ -521,15 +662,25 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
luma_concepts: LumaConceptChain = None,
) -> IO.NodeOutput:
if first_image is None and last_image is None:
raise Exception("At least one of first_image and last_image requires an input.")
keyframes = await cls._convert_to_keyframes(first_image, last_image)
raise Exception(
"At least one of first_image and last_image requires an input."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
@@ -539,31 +690,54 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_poll = await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod
async def _convert_to_keyframes(
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
frame0 = None
frame1 = None
if first_image is not None:
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)

View File

@@ -1,57 +1,71 @@
from inspect import cleandoc
from typing import Optional
import logging
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel,
)
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
upload_images_to_comfyapi,
)
from comfy_api_nodes.util import validate_string
from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
async def _generate_mm_video(
cls: type[IO.ComfyNode],
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> IO.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -59,50 +73,81 @@ async def _generate_mm_video(
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
class MinimaxTextToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description="Generates videos synchronously based on a prompt, and optional parameters.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt_text",
@@ -144,7 +189,11 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
cls,
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -155,13 +204,17 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
class MinimaxImageToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input(
"image",
@@ -208,7 +261,11 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
cls,
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -219,13 +276,17 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
class MinimaxSubjectToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input(
"subject",
@@ -272,7 +333,11 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
cls,
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -283,13 +348,15 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
class MinimaxHailuoVideoNode(IO.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt_text",
@@ -353,6 +420,10 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@@ -364,13 +435,16 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -379,42 +453,67 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
duration=duration,
resolution=resolution,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
class MinimaxExtension(ComfyExtension):

File diff suppressed because it is too large Load Diff

View File

@@ -7,23 +7,24 @@ from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
from typing import Optional, TypeVar
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
ApiEndpoint,
sync_op,
poll_op,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
)
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -39,18 +40,28 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60,
max_poll_attempts=240,
)
).execute()
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
@@ -113,15 +124,23 @@ class PikaImageToVideo(IO.ComfyNode):
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaTextToVideoNode(IO.ComfyNode):
@@ -164,11 +183,18 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -176,9 +202,10 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaScenes(IO.ComfyNode):
@@ -282,16 +309,24 @@ class PikaScenes(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikAdditionsNode(IO.ComfyNode):
@@ -348,16 +383,24 @@ class PikAdditionsNode(IO.ComfyNode):
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaSwapsNode(IO.ComfyNode):
@@ -429,15 +472,23 @@ class PikaSwapsNode(IO.ComfyNode):
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaffectsNode(IO.ComfyNode):
@@ -477,11 +528,18 @@ class PikaffectsNode(IO.ComfyNode):
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -489,8 +547,9 @@ class PikaffectsNode(IO.ComfyNode):
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaStartEndFrameNode(IO.ComfyNode):
@@ -533,11 +592,18 @@ class PikaStartEndFrameNode(IO.ComfyNode):
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -546,8 +612,9 @@ class PikaStartEndFrameNode(IO.ComfyNode):
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation.video_id, cls)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaApiNodesExtension(ComfyExtension):

View File

@@ -1,6 +1,7 @@
import torch
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -16,30 +17,53 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import validate_string, tensor_to_bytesio
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
import torch
import aiohttp
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response_upload = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
response_model=PixverseImageUploadResponse,
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
@@ -69,13 +93,17 @@ class PixverseTemplateNode(IO.ComfyNode):
class PixverseTextToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
@@ -142,7 +170,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
negative_prompt: str = None,
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
@@ -151,11 +179,18 @@ class PixverseTextToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTextVideoRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
@@ -165,14 +200,20 @@ class PixverseTextToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -180,19 +221,30 @@ class PixverseTextToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@@ -257,7 +309,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(cls, image)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -267,11 +323,14 @@ class PixverseImageToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseImageVideoRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/img/generate",
method=HttpMethod.POST,
request_model=PixverseImageVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
@@ -281,15 +340,20 @@ class PixverseImageToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -297,19 +361,30 @@ class PixverseImageToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"),
@@ -370,8 +445,12 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt: str = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -381,11 +460,14 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTransitionVideoRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/transition/generate",
method=HttpMethod.POST,
request_model=PixverseTransitionVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
@@ -395,15 +477,20 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -411,9 +498,16 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixVerseExtension(ComfyExtension):

View File

@@ -8,6 +8,9 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
)
from comfy_api_nodes.apis.recraft_api import (
RecraftColor,
RecraftColorChain,
@@ -25,7 +28,6 @@ from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
tensor_to_bytesio,
validate_string,

View File

@@ -5,9 +5,12 @@ Rodin API docs: https://developer.hyper3d.ai/
"""
from __future__ import annotations
from inspect import cleandoc
import folder_paths as comfy_paths
import aiohttp
import os
import asyncio
import logging
import math
from typing import Optional
@@ -23,11 +26,11 @@ from comfy_api_nodes.apis.rodin_api import (
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.util import (
sync_op,
poll_op,
from comfy_api_nodes.apis.client import (
ApiEndpoint,
download_url_to_bytesio,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
from comfy_api.latest import ComfyExtension, IO
@@ -118,31 +121,35 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
async def create_generate_task(
cls: type[IO.ComfyNode],
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
ta_pose: bool = False,
TAPose = False,
auth_kwargs: Optional[dict[str, str]] = None,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=Rodin3DGenerateRequest(
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=ta_pose,
TAPose=TAPose,
),
files=[
(
@@ -152,8 +159,11 @@ async def create_generate_task(
for image in images if image is not None
],
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response = await operation.execute()
if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
@@ -177,46 +187,75 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
if not response.jobs:
return None
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
return int((completed_count / len(response.jobs)) * 100)
async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
response_model=Rodin3DCheckStatusResponse,
data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
async def poll_for_task_status(
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/status",
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=check_rodin_status,
progress_extractor=extract_progress,
poll_interval=3.0,
auth_kwargs=auth_kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
return await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
response_model=Rodin3DDownloadResponse,
data=Rodin3DDownloadRequest(task_uuid=uuid),
monitor_progress=False,
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/download",
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(task_uuid=uuid),
auth_kwargs=auth_kwargs,
)
return await operation.execute()
async def download_files(url_list, task_uuid: str):
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
os.makedirs(save_path, exist_ok=True)
model_file_path = None
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
if file_path.endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
await download_url_to_bytesio(i.url, file_path)
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
@@ -238,7 +277,6 @@ class Rodin3D_Regular(IO.ComfyNode):
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -257,17 +295,21 @@ class Rodin3D_Regular(IO.ComfyNode):
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return IO.NodeOutput(model)
@@ -291,7 +333,6 @@ class Rodin3D_Detail(IO.ComfyNode):
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -310,17 +351,21 @@ class Rodin3D_Detail(IO.ComfyNode):
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return IO.NodeOutput(model)
@@ -344,7 +389,6 @@ class Rodin3D_Smooth(IO.ComfyNode):
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -357,22 +401,27 @@ class Rodin3D_Smooth(IO.ComfyNode):
Material_Type,
Polygon_count,
) -> IO.NodeOutput:
tier = "Smooth"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier="Smooth",
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return IO.NodeOutput(model)
@@ -403,7 +452,6 @@ class Rodin3D_Sketch(IO.ComfyNode):
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -414,21 +462,29 @@ class Rodin3D_Sketch(IO.ComfyNode):
Images,
Seed,
) -> IO.NodeOutput:
tier = "Sketch"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality_override = 18000
mesh_mode = "Quad"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material="PBR",
quality_override=18000,
tier="Sketch",
mesh_mode="Quad",
material=material_type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return IO.NodeOutput(model)
@@ -467,7 +523,6 @@ class Rodin3D_Gen2(IO.ComfyNode):
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@@ -487,18 +542,22 @@ class Rodin3D_Gen2(IO.ComfyNode):
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
ta_pose=TAPose,
TAPose=TAPose,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return IO.NodeOutput(model)

View File

@@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
download_urls = await upload_images_to_comfyapi(
cls,
@@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
download_urls = await upload_images_to_comfyapi(
cls,
@@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
@@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
reference_images = None
if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
download_urls = await upload_images_to_comfyapi(
cls,
reference_image,

View File

@@ -20,6 +20,13 @@ from comfy_api_nodes.apis.stability_api import (
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
@@ -27,9 +34,6 @@ from comfy_api_nodes.util import (
bytesio_to_image_tensor,
tensor_to_bytesio,
audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
)
import torch
@@ -157,11 +161,19 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -171,7 +183,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -299,11 +313,19 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -316,7 +338,9 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -403,11 +427,19 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -415,7 +447,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -510,11 +544,19 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -523,15 +565,25 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/stability/v2beta/results/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -576,13 +628,24 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -654,13 +717,21 @@ class StabilityTextToAudio(IO.ComfyNode):
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -743,14 +814,22 @@ class StabilityAudioToAudio(IO.ComfyNode):
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -856,14 +935,22 @@ class StabilityAudioInpaint(IO.ComfyNode):
mask_start=mask_start,
mask_end=mask_end,
)
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

View File

@@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio,
validate_aspect_ratio_closeness,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
)
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@@ -114,7 +114,7 @@ async def execute_task(
cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
status_extractor=lambda r: r.state.value,
estimated_duration=estimated_duration,
)
@@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
@@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,

View File

@@ -14,12 +14,9 @@ from .conversions import (
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
video_to_base64_string,
)
@@ -37,12 +34,12 @@ from .upload_helpers import (
)
from .validation_utils import (
get_number_of_images,
validate_aspect_ratio_string,
validate_aspect_ratio_closeness,
validate_audio_duration,
validate_container_format_is_mp4,
validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string,
validate_video_dimensions,
validate_video_duration,
@@ -73,22 +70,19 @@ __all__ = [
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"video_to_base64_string",
# Validation utilities
"get_number_of_images",
"validate_aspect_ratio_string",
"validate_aspect_ratio_closeness",
"validate_audio_duration",
"validate_container_format_is_mp4",
"validate_image_aspect_ratio",
"validate_image_aspect_ratio_range",
"validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string",
"validate_video_dimensions",
"validate_video_duration",

View File

@@ -16,9 +16,9 @@ from pydantic import BaseModel
from comfy import utils
from comfy_api.latest import IO
from comfy_api_nodes.apis import request_logger
from server import PromptServer
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
@@ -77,8 +77,8 @@ class _PollUIState:
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
@@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = {"Accept": "*/*"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:

View File

@@ -1,7 +1,6 @@
import base64
import logging
import math
import mimetypes
import uuid
from io import BytesIO
from typing import Optional
@@ -13,7 +12,7 @@ from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.util import VideoContainer, VideoCodec
from ._helpers import mimetype_to_extension
@@ -431,40 +430,3 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api_nodes.apis import request_logger
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
@@ -232,12 +232,11 @@ async def download_url_to_video_output(
video_url: str,
*,
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls)
return VideoFromFile(result)

View File

@@ -13,8 +13,8 @@ from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
ApiEndpoint,

View File

@@ -37,62 +37,63 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
min_aspect_ratio: Optional[float] = None,
max_aspect_ratio: Optional[float] = None,
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
a1, b1 = min_ratio
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
w, h = get_image_dimensions(image)
if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
return ar
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
min_rel: float, # e.g. 0.8
max_rel: float, # e.g. 1.25
def validate_aspect_ratio_closeness(
start_img,
end_img,
min_rel: float,
max_rel: float,
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""
Validates that the two images' aspect ratios are 'close'.
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
Returns the computed closeness factor C.
"""
w1, h1 = get_image_dimensions(first_image)
w2, h2 = get_image_dimensions(second_image)
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions")
ar1 = w1 / h1
ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
f"allowed range {min_rel}{max_rel} (limit {limit:.2g})."
)
return closeness
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
ar = _parse_aspect_ratio_string(aspect_ratio)
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")
def validate_video_dimensions(
@@ -182,49 +183,3 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _ratio_from_tuple(r: tuple[float, float]) -> float:
a, b = r
if a <= 0 or b <= 0:
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
return a / b
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
if lo is not None and hi is not None and lo > hi:
lo, hi = hi, lo # normalize order if caller swapped them
if lo is not None:
if (ar <= lo) if strict else (ar < lo):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
if hi is not None:
if (ar >= hi) if strict else (ar > hi):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
def _parse_aspect_ratio_string(ar_str: str) -> float:
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
parts = ar_str.split(":")
if len(parts) != 2:
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except ValueError as exc:
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
if a <= 0 or b <= 0:
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
return a / b

View File

@@ -1,9 +1,4 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -53,7 +48,7 @@ class Unhashable:
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
@@ -193,9 +188,6 @@ class BasicCache:
self._clean_cache()
self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -284,9 +276,6 @@ class NullCache:
def clean_unused(self):
pass
def poll(self, **kwargs):
pass
def get(self, node_id):
return None
@@ -347,77 +336,3 @@ class LRUCache(BasicCache):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()

View File

@@ -209,15 +209,10 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_cache(self, from_node_id, to_node_id):
def get_output_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
return self.execution_cache[to_node_id].get(from_node_id)
def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class FlipFlop(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FlipFlopNew",
display_name="FlipFlop (New)",
category="_for_testing",
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="block_percentage", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output()
],
description="Apply FlipFlop transformation to model using setup_flipflop_holders method"
)
@classmethod
def execute(cls, model: io.Model.Type, block_percentage: float) -> io.NodeOutput:
# NOTE: this is just a hacky prototype still, this would not be exposed as a node.
# At the moment, this modifies the underlying model with no way to 'unpatch' it.
model = model.clone()
if not hasattr(model.model.diffusion_model, "setup_flipflop_holders"):
raise ValueError("Model does not have flipflop holders; FlipFlop not supported")
model.model.diffusion_model.setup_flipflop_holders(block_percentage)
return io.NodeOutput(model)
class FlipFlopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FlipFlop,
]
async def comfy_entrypoint() -> FlipFlopExtension:
return FlipFlopExtension()

View File

@@ -2,9 +2,6 @@ import comfy.utils
import folder_paths
import torch
import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
@@ -97,42 +94,27 @@ def load_hypernetwork_patch(path, strength):
return hypernetwork_patch(out, strength)
class HypernetworkLoader(IO.ComfyNode):
class HypernetworkLoader:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
category="loaders",
inputs=[
IO.Model.Input("model"),
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
@classmethod
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return IO.NodeOutput(model_hypernetwork)
return (model_hypernetwork,)
load_hypernetwork = execute # TODO: remove
class HyperNetworkExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
HypernetworkLoader,
]
async def comfy_entrypoint() -> HyperNetworkExtension:
return HyperNetworkExtension()
NODE_CLASS_MAPPINGS = {
"HypernetworkLoader": HypernetworkLoader
}

View File

@@ -1,47 +0,0 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.68"
__version__ = "0.3.67"

View File

@@ -21,7 +21,6 @@ from comfy_execution.caching import (
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -89,56 +88,49 @@ class IsChangedCache:
return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum):
CLASSIC = 0
LRU = 1
NONE = 2
RAM_PRESSURE = 3
class CacheSet:
def __init__(self, cache_type=None, cache_args={}):
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
cache_size = cache_args.get("lru", 0)
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.objects]
self.all = [self.outputs, self.ui, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Heirachical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
@@ -165,14 +157,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None:
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
if cached_output is None:
mark_missing()
continue
if output_index >= len(cached.outputs):
if output_index >= len(cached_output):
mark_missing()
continue
obj = cached.outputs[output_index]
obj = cached_output[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@@ -401,7 +393,7 @@ def format_value(x):
else:
return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -409,15 +401,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id)
if cached is not None:
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@@ -447,8 +436,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_cached.outputs[source_output]:
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
for o in node_output:
resolved_output.append(o)
else:
@@ -518,7 +507,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
ui_outputs[unique_id] = {
caches.ui.set(unique_id, {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
@@ -526,7 +515,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"real_node_id": real_node_id,
},
"output": output_ui
}
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
@@ -565,9 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -612,14 +600,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.status_messages = []
self.success = True
@@ -694,7 +682,6 @@ class PromptExecutor:
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -708,7 +695,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -717,16 +704,18 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,

View File

@@ -172,12 +172,10 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

View File

@@ -2329,7 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
"nodes_flipflop.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.68"
version = "0.3.67"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.28.8
comfyui-workflow-templates==0.3.1
comfyui-embedded-docs==0.3.1
comfyui-workflow-templates==0.2.4
comfyui-embedded-docs==0.3.0
torch
torchsde
torchvision

View File

@@ -29,7 +29,7 @@ import comfy.model_management
from comfy_api import feature_flags
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version
from app.frontend_management import FrontendManager
from comfy_api.internal import _ComfyNodeInternal
from app.user_manager import UserManager
@@ -847,31 +847,11 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
installed_templates_version = FrontendManager.get_installed_templates_version()
use_legacy_templates = True
if installed_templates_version:
try:
use_legacy_templates = (
parse_version(installed_templates_version)
< parse_version("0.3.0")
)
except Exception as exc:
logging.warning(
"Unable to parse templates version '%s': %s",
installed_templates_version,
exc,
)
if use_legacy_templates:
workflow_templates_path = FrontendManager.legacy_templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
else:
handler = FrontendManager.template_asset_handler()
if handler:
self.app.router.add_get("/templates/{path:.*}", handler)
workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
# Serve embedded documentation from the package
embedded_docs_path = FrontendManager.embedded_docs_path()

View File

@@ -14,7 +14,7 @@ if not has_gpu():
args.cpu = True
from comfy import ops
from comfy.quant_ops import QuantizedTensor
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)

View File

@@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
def test_dequantize(self):
"""Test explicit dequantization"""
@@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
dequantized = qt.dequantize()
self.assertEqual(dequantized.dtype, torch.float32)
@@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
qt = QuantizedTensor.from_float(
float_tensor,
"TensorCoreFP8Layout",
TensorCoreFP8Layout,
scale=scale,
dtype=torch.float8_e4m3fn
)
@@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Detach should return a new QuantizedTensor
qt_detached = qt.detach()
self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
def test_clone(self):
"""Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Clone should return a new QuantizedTensor
qt_cloned = qt.clone()
self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
# Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata)
@@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu')
@@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
scale = torch.tensor(1.0)
a_q = QuantizedTensor.from_float(
a_fp32,
"TensorCoreFP8Layout",
TensorCoreFP8Layout,
scale=scale,
dtype=torch.float8_e4m3fn
)