mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 04:00:03 +00:00
Compare commits
66 Commits
v0.9.1
...
feature/fr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b40a25a31d | ||
|
|
7484c9c237 | ||
|
|
8adafb4d65 | ||
|
|
3c0365f6d6 | ||
|
|
9c7d5f1fdd | ||
|
|
2c37119ff8 | ||
|
|
72f6be1690 | ||
|
|
16b9aabd52 | ||
|
|
245f6139b6 | ||
|
|
3365ad18a5 | ||
|
|
f09904720d | ||
|
|
191834c633 | ||
|
|
abe2ec26a6 | ||
|
|
5faf2e3cfd | ||
|
|
bdeac8897e | ||
|
|
451af70154 | ||
|
|
0fc15700be | ||
|
|
e755268e7b | ||
|
|
c4a14df9a3 | ||
|
|
965d0ed509 | ||
|
|
ddc541ffda | ||
|
|
8ccc0c94fa | ||
|
|
4edb87aa50 | ||
|
|
0fc3b6e3a6 | ||
|
|
2108167f9f | ||
|
|
9d273d3ab1 | ||
|
|
70c91b8248 | ||
|
|
0da5a0fe58 | ||
|
|
e0eacb0688 | ||
|
|
7458e20465 | ||
|
|
b931b37e30 | ||
|
|
866a4619db | ||
|
|
1a72bf2046 | ||
|
|
034fac7054 | ||
|
|
a498556d0d | ||
|
|
f7ca41ff62 | ||
|
|
ac26065e61 | ||
|
|
190c4416cc | ||
|
|
0fd10ffa09 | ||
|
|
00c775950a | ||
|
|
7ac999bf30 | ||
|
|
0c6b36c6ac | ||
|
|
9125613b53 | ||
|
|
732b707397 | ||
|
|
4c816d5c69 | ||
|
|
6125b3a5e7 | ||
|
|
12918a5f78 | ||
|
|
8f40b43e02 | ||
|
|
3b832231bb | ||
|
|
be518db5a7 | ||
|
|
80441eb15e | ||
|
|
07f2462eae | ||
|
|
d150440466 | ||
|
|
6165c38cb5 | ||
|
|
712cca36a1 | ||
|
|
ac4d8ea9b3 | ||
|
|
c9196f355e | ||
|
|
7eb959ce93 | ||
|
|
469dd9c16a | ||
|
|
eff2b9d412 | ||
|
|
15b312de7a | ||
|
|
1419047fdb | ||
|
|
79f6bb5e4f | ||
|
|
e4b4fb3479 | ||
|
|
d9dc02a7d6 | ||
|
|
c543ad81c3 |
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
repository: "Comfy-Org/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
|
||||
59
.github/workflows/update-ci-container.yml
vendored
Normal file
59
.github/workflows/update-ci-container.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
name: "CI: Update CI Container"
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'ComfyUI version (e.g., v0.7.0)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-ci-container:
|
||||
runs-on: ubuntu-latest
|
||||
# Skip pre-releases unless manually triggered
|
||||
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
|
||||
steps:
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
VERSION="${{ github.event.release.tag_name }}"
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout comfyui-ci-container
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: comfy-org/comfyui-ci-container
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
|
||||
- name: Check current version
|
||||
id: current
|
||||
run: |
|
||||
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
|
||||
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update Dockerfile
|
||||
run: |
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
|
||||
|
||||
- name: Create Pull Request
|
||||
id: create-pr
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
branch: automation/comfyui-${{ steps.version.outputs.version }}
|
||||
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
body: |
|
||||
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
|
||||
|
||||
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
|
||||
|
||||
labels: automation
|
||||
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
@@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
@@ -240,7 +240,7 @@ These have less hardware support than the builds above but they work on windows.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
|
||||
23
app/node_replace_manager.py
Normal file
23
app/node_replace_manager.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._node_replace import NodeReplace
|
||||
|
||||
REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register_node_replacement(node_replace: NodeReplace):
|
||||
REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def registered_as_dict():
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
|
||||
}
|
||||
|
||||
class NodeReplaceManager:
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(registered_as_dict())
|
||||
@@ -10,6 +10,7 @@ import hashlib
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
templates = "templates"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
@@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
|
||||
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
|
||||
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": source,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": {"node_pack": node_pack},
|
||||
}
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
@@ -60,53 +73,60 @@ class SubgraphManager:
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
"""Load subgraphs from custom nodes."""
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
pattern = os.path.join(folder, "*/subgraphs/*.json")
|
||||
for file in glob.glob(pattern):
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
node_pack = "custom_nodes." + file.split('/')[-3]
|
||||
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
async def get_blueprint_subgraphs(self, force_reload=False):
|
||||
"""Load subgraphs from the blueprints directory."""
|
||||
if not force_reload and self.cached_blueprint_subgraphs is not None:
|
||||
return self.cached_blueprint_subgraphs
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
|
||||
|
||||
if os.path.exists(blueprints_dir):
|
||||
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
||||
file = file.replace('\\', '/')
|
||||
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_blueprint_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_all_subgraphs(self, loadedModules, force_reload=False):
|
||||
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
|
||||
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
|
||||
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
|
||||
return {**custom_node_subgraphs, **blueprint_subgraphs}
|
||||
|
||||
async def get_subgraph(self, id: str, loadedModules):
|
||||
"""Get a specific subgraph by ID from any source."""
|
||||
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
|
||||
if entry is not None and entry.get('data') is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
subgraph = await self.get_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
|
||||
0
blueprints/put_blueprints_here
Normal file
0
blueprints/put_blueprints_here
Normal file
@@ -66,6 +66,7 @@ class ClipVisionModel():
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
|
||||
@@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
|
||||
|
||||
x = x.view(orig_shape).nan_to_num()
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
return data_lp, scaled_block_scales_fp8
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
|
||||
return x, to_blocked(blocked_scaled, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
@@ -158,28 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
orig_shape = list(orig_shape)
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
|
||||
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
num_slices = max(1, (x.numel() / block_size))
|
||||
slice_size = max(1, (round(x.shape[0] / num_slices)))
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
for i in range(0, x.shape[0], slice_size):
|
||||
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
|
||||
output_fp4[i:i + slice_size].copy_(fp4)
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
202
comfy/ldm/anima/model.py
Normal file
202
comfy/ldm/anima/model.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, head_dim):
|
||||
super().__init__()
|
||||
self.rope_theta = 10000
|
||||
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = head_dim * n_heads
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
|
||||
context = x if context is None else context
|
||||
input_shape = x.shape[:-1]
|
||||
q_shape = (*input_shape, self.n_heads, self.head_dim)
|
||||
context_shape = context.shape[:-1]
|
||||
kv_shape = (*context_shape, self.n_heads, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
assert position_embeddings_context is not None
|
||||
cos, sin = position_embeddings
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin)
|
||||
cos, sin = position_embeddings_context
|
||||
key_states = apply_rotary_pos_emb(key_states, cos, sin)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.o_proj.weight)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.use_self_attn = use_self_attn
|
||||
|
||||
if self.use_self_attn:
|
||||
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=model_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=source_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
|
||||
if self.use_self_attn:
|
||||
normed = self.norm_self_attn(x)
|
||||
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
|
||||
x = x + attn_out
|
||||
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.mlp[2].weight)
|
||||
self.cross_attn.init_weights()
|
||||
|
||||
|
||||
class LLMAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
source_dim=1024,
|
||||
target_dim=1024,
|
||||
model_dim=1024,
|
||||
num_layers=6,
|
||||
num_heads=16,
|
||||
use_self_attn=True,
|
||||
layer_norm=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
|
||||
if model_dim != target_dim:
|
||||
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.in_proj = nn.Identity()
|
||||
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
|
||||
if target_attention_mask is not None:
|
||||
target_attention_mask = target_attention_mask.to(torch.bool)
|
||||
if target_attention_mask.ndim == 2:
|
||||
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if source_attention_mask is not None:
|
||||
source_attention_mask = source_attention_mask.to(torch.bool)
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
position_embeddings_context = self.rotary_emb(x, position_ids_context)
|
||||
for block in self.blocks:
|
||||
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
return self.norm(self.out_proj(x))
|
||||
|
||||
|
||||
class Anima(MiniTrainDIT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
else:
|
||||
return text_embeds
|
||||
@@ -103,20 +103,10 @@ class AudioPreprocessor:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def normalize_amplitude(
|
||||
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||
) -> torch.Tensor:
|
||||
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||
peak = torch.max(torch.abs(waveform)) + eps
|
||||
scale = peak.clamp(max=max_amplitude) / peak
|
||||
return waveform * scale
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
waveform = self.normalize_amplitude(waveform)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
@@ -189,9 +179,12 @@ class AudioVAE(torch.nn.Module):
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
if waveform.shape[1] == 1:
|
||||
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
|
||||
@@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def modulate(x, scale, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
else:
|
||||
scale = (1 + scale.unsqueeze(1))
|
||||
actual_batch = scale.size(0) // 2
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= scale[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= scale[:actual_batch]
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(gate, x, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return gate * x
|
||||
else:
|
||||
actual_batch = gate.size(0) // 2
|
||||
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= gate[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= gate[:actual_batch]
|
||||
return x
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
@@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
@@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
@@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
def forward(self, x, c, timestep_zero_index=None):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def pad_zimage(feats, pad_token, pad_tokens_multiple):
|
||||
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
|
||||
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
|
||||
|
||||
|
||||
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = start_t
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
return x_pos_ids
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
@@ -378,6 +446,7 @@ class NextDiT(nn.Module):
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
siglip_feat_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -491,6 +560,41 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
siglip_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
@@ -531,70 +635,168 @@ class NextDiT(nn.Module):
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
|
||||
if cap_feats is not None:
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats_len = cap_feats.shape[1]
|
||||
if self.pad_tokens_multiple is not None:
|
||||
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
|
||||
else:
|
||||
cap_feats_len = 0
|
||||
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
|
||||
embeds = (cap_feats,)
|
||||
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len
|
||||
|
||||
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
|
||||
bsz = 1
|
||||
pH = pW = self.patch_size
|
||||
device = x.device
|
||||
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
||||
|
||||
if (not omni) or self.siglip_embedder is None:
|
||||
cap_feats_len = embeds[0].shape[1] + offset
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
cap_feats_len += offset
|
||||
if siglip_feats is not None:
|
||||
b, h, w, c = siglip_feats.shape
|
||||
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
|
||||
siglip_feats = self.siglip_embedder(siglip_feats)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
|
||||
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
|
||||
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
||||
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
||||
else:
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
|
||||
if siglip_feats is None:
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
embeds += (siglip_feats,)
|
||||
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
embeds += (x,)
|
||||
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
|
||||
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = x.shape[0]
|
||||
cap_mask = None # TODO?
|
||||
main_siglip = None
|
||||
orig_x = x
|
||||
|
||||
embeds = ([], [], [])
|
||||
freqs_cis = ([], [], [])
|
||||
leftover_cap = []
|
||||
|
||||
start_t = 0
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
for i, ref in enumerate(ref_latents):
|
||||
if i < len(ref_contexts):
|
||||
ref_con = ref_contexts[i]
|
||||
else:
|
||||
ref_con = None
|
||||
if i < len(siglip_feats):
|
||||
sig_feat = siglip_feats[i]
|
||||
else:
|
||||
sig_feat = None
|
||||
|
||||
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
leftover_cap = ref_contexts[len(ref_latents):]
|
||||
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
img_sizes = [(H, W)] * bsz
|
||||
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
img_len = out[0][-1].shape[1]
|
||||
cap_len = out[0][0].shape[1]
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
e = comfy.utils.repeat_to_batch_size(e, bsz)
|
||||
embeds[i].append(e)
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
|
||||
for cap in leftover_cap:
|
||||
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
|
||||
cap_len += out[0][0].shape[1]
|
||||
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
|
||||
freqs_cis[0].append(out[1][0])
|
||||
start_t += out[2]
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
cap_feats = torch.cat(embeds[0], dim=1)
|
||||
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
feats = (cap_feats,)
|
||||
fc = (cap_freqs_cis,)
|
||||
|
||||
if omni and len(embeds[1]) > 0:
|
||||
siglip_mask = None
|
||||
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
||||
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
||||
if self.siglip_refiner is not None:
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
|
||||
feats += (siglip_feats_combined,)
|
||||
fc += (siglip_feats_freqs_cis,)
|
||||
|
||||
padded_img_mask = None
|
||||
x = torch.cat(embeds[-1], dim=1)
|
||||
fc_x = torch.cat(freqs_cis[-1], dim=1)
|
||||
if omni:
|
||||
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
|
||||
else:
|
||||
timestep_zero_index = None
|
||||
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
padded_full_embed = torch.cat(feats + (x,), dim=1)
|
||||
if timestep_zero_index is not None:
|
||||
ind = padded_full_embed.shape[1] - x.shape[1]
|
||||
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
|
||||
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
|
||||
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -604,7 +806,11 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@@ -619,8 +825,6 @@ class NextDiT(nn.Module):
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
@@ -632,7 +836,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
@@ -640,7 +844,7 @@ class NextDiT(nn.Module):
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
@@ -649,8 +853,7 @@ class NextDiT(nn.Module):
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -img
|
||||
|
||||
|
||||
@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
def qkv_fn_q(x):
|
||||
@@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "attn1_patch" in patches:
|
||||
for p in patches["attn1_patch"]:
|
||||
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
@@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_patch" in patches:
|
||||
for p in patches["attn2_patch"]:
|
||||
x = p({"x": x, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
|
||||
self.blocks = nn.ModuleList([
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
@@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@@ -738,6 +752,7 @@ class VaceWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
|
||||
500
comfy/ldm/wan/model_multitalk.py
Normal file
500
comfy/ldm/wan/model_multitalk.py
Normal file
@@ -0,0 +1,500 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
||||
scale = 1.0 / visual_q.shape[-1] ** 0.5
|
||||
visual_q = visual_q.transpose(1, 2) * scale
|
||||
|
||||
B, H, x_seqlens, K = visual_q.shape
|
||||
|
||||
x_ref_attn_maps = []
|
||||
for class_idx, ref_target_mask in enumerate(ref_target_masks):
|
||||
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
|
||||
|
||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||
|
||||
for i in range(0, x_seqlens, chunk_size):
|
||||
end_i = min(i + chunk_size, x_seqlens)
|
||||
|
||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||
|
||||
# Apply softmax
|
||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||
attn_chunk = (attn_chunk - attn_max).exp()
|
||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||
|
||||
# Apply mask and sum
|
||||
masked_attn = attn_chunk * ref_target_mask
|
||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||
|
||||
del attn_chunk, masked_attn
|
||||
|
||||
# Average across heads
|
||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||
x_ref_attn_maps.append(x_ref_attnmap)
|
||||
|
||||
del visual_q, ref_k
|
||||
|
||||
return torch.cat(x_ref_attn_maps, dim=0)
|
||||
|
||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||
"""Args:
|
||||
query (torch.tensor): B M H K
|
||||
key (torch.tensor): B M H K
|
||||
shape (tuple): (N_t, N_h, N_w)
|
||||
ref_target_masks: [B, N_h * N_w]
|
||||
"""
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_seqlens = N_h * N_w
|
||||
ref_k = ref_k[:, :x_seqlens]
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
|
||||
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_target_masks
|
||||
)
|
||||
x_ref_attn_maps += x_ref_attn_maps_perhead
|
||||
|
||||
return x_ref_attn_maps / split_num
|
||||
|
||||
|
||||
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
|
||||
source_min, source_max = source_range
|
||||
new_min, new_max = target_range
|
||||
normalized = (column - source_min) / (source_max - source_min + epsilon)
|
||||
scaled = normalized * (new_max - new_min) + new_min
|
||||
return scaled
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
def get_audio_embeds(encoded_audio, audio_start, audio_end):
|
||||
audio_embs = []
|
||||
human_num = len(encoded_audio)
|
||||
audio_frames = encoded_audio[0].shape[0]
|
||||
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
|
||||
for human_idx in range(human_num):
|
||||
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
|
||||
pad_len = audio_end - audio_frames
|
||||
pad_shape = list(encoded_audio[human_idx].shape)
|
||||
pad_shape[0] = pad_len
|
||||
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
|
||||
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
|
||||
else:
|
||||
encoded_audio_in = encoded_audio[human_idx]
|
||||
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
|
||||
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
|
||||
audio_embs.append(audio_emb)
|
||||
|
||||
return torch.cat(audio_embs, dim=0)
|
||||
|
||||
|
||||
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
|
||||
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
|
||||
|
||||
first_frame_audio_emb_s = audio_embs[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_embs[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
|
||||
|
||||
middle_index = audio_proj.seq_len // 2
|
||||
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
audio_emb = torch.cat(audio_emb.split(1), dim=2)
|
||||
|
||||
return audio_emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = 10000
|
||||
|
||||
def precompute_freqs_cis_1d(self, pos_indices):
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
|
||||
freqs = freqs.to(pos_indices.device)
|
||||
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
|
||||
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, pos_indices):
|
||||
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
|
||||
|
||||
x_ = x.float()
|
||||
|
||||
freqs_cis = freqs_cis.float().to(x.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
x_ = (x_ * cos) + (rotate_half(x_) * sin)
|
||||
|
||||
return x_.type_as(x)
|
||||
|
||||
class SingleStreamAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
expected_tokens = N_t * N_h * N_w
|
||||
actual_tokens = x.shape[1]
|
||||
x_extra = None
|
||||
|
||||
if actual_tokens != expected_tokens:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
|
||||
B = x.shape[0]
|
||||
S = N_h * N_w
|
||||
x = x.view(B * N_t, S, self.dim)
|
||||
|
||||
# get q for hidden_state
|
||||
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
|
||||
|
||||
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
|
||||
kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
|
||||
|
||||
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# linear transform
|
||||
x = self.proj(x.reshape(B * N_t, S, self.dim))
|
||||
x = x.view(B, N_t * S, self.dim)
|
||||
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class SingleStreamMultiAttention(SingleStreamAttention):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
class_range: int = 24,
|
||||
class_interval: int = 4,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
# Rotary-embedding layout parameters
|
||||
self.class_interval = class_interval
|
||||
self.class_range = class_range
|
||||
self.max_humans = self.class_range // self.class_interval
|
||||
|
||||
# Constant bucket used for background tokens
|
||||
self.rope_bak = int(self.class_range // 2)
|
||||
|
||||
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
shape=None,
|
||||
x_ref_attn_map=None
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
|
||||
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
|
||||
# Single-speaker fall-through
|
||||
if human_num <= 1:
|
||||
return super().forward(x, encoder_hidden_states, shape)
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_extra = None
|
||||
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
||||
|
||||
# Query projection
|
||||
B, N, C = x.shape
|
||||
q = self.q_linear(x)
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||
|
||||
# Use `class_range` logic for 2 speakers
|
||||
rope_h1 = (0, self.class_interval)
|
||||
rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
||||
rope_bak = int(self.class_range // 2)
|
||||
|
||||
# Normalize and scale attention maps for each speaker
|
||||
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
||||
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
||||
max_min_values = torch.cat([max_values, min_values], dim=2)
|
||||
|
||||
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
||||
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
|
||||
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
|
||||
# Token-wise speaker dominance
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
|
||||
|
||||
# Apply rotary to Q
|
||||
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
q = self.rope_1d(q, normalized_pos)
|
||||
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Keys / Values
|
||||
_, N_a, _ = encoder_hidden_states.shape
|
||||
encoder_kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
encoder_k, encoder_v = encoder_kv.unbind(0)
|
||||
|
||||
# Rotary for keys – assign centre of each speaker bucket to its context tokens
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
|
||||
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
|
||||
|
||||
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
||||
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Final attention
|
||||
q = rearrange(q, "B H M K -> B M H K")
|
||||
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
||||
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
||||
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# Linear projection
|
||||
x = x.reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# Restore original layout
|
||||
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len: int = 5,
|
||||
seq_len_vf: int = 12,
|
||||
blocks: int = 12,
|
||||
channels: int = 768,
|
||||
intermediate_dim: int = 512,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.out_dim = out_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
self.audio_scale = audio_scale
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
block_idx = transformer_options.get("block_index", None)
|
||||
x = kwargs["x"]
|
||||
if block_idx is None:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class MultiTalkApplyModelWrapper:
|
||||
def __init__(self, init_latents):
|
||||
self.init_latents = init_latents
|
||||
|
||||
def __call__(self, executor, x, *args, **kwargs):
|
||||
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
|
||||
samples = executor(x, *args, **kwargs)
|
||||
return samples
|
||||
|
||||
|
||||
class InfiniteTalkOuterSampleWrapper:
|
||||
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
|
||||
self.motion_frames_latent = motion_frames_latent
|
||||
self.model_patch = model_patch
|
||||
self.is_extend = is_extend
|
||||
|
||||
def __call__(self, executor, *args, **kwargs):
|
||||
model_patcher = executor.class_obj.model_patcher
|
||||
model_options = executor.class_obj.model_options
|
||||
process_latent_in = model_patcher.model.process_latent_in
|
||||
|
||||
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
|
||||
if self.motion_frames_latent is not None:
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
|
||||
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
|
||||
|
||||
# run the sampling process
|
||||
result = executor(*args, **kwargs)
|
||||
|
||||
# insert motion frames before decoding
|
||||
if self.is_extend:
|
||||
overlap = self.motion_frames_latent.shape[2]
|
||||
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
|
||||
|
||||
return result
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
if self.motion_frames_latent is not None:
|
||||
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
|
||||
return self
|
||||
@@ -49,6 +49,7 @@ import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -1147,9 +1148,31 @@ class CosmosPredict2(BaseModel):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
return latent_image / (1.0 - sigma)
|
||||
|
||||
class Anima(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
t5xxl_ids = kwargs.get("t5xxl_ids", None)
|
||||
t5xxl_weights = kwargs.get("t5xxl_weights", None)
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1169,6 +1192,35 @@ class Lumina2(BaseModel):
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
|
||||
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
|
||||
sigfeats = []
|
||||
for clip_vision_output in clip_vision_outputs:
|
||||
if clip_vision_output is not None:
|
||||
image_size = clip_vision_output.image_sizes[0]
|
||||
shape = clip_vision_output.last_hidden_state.shape
|
||||
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
|
||||
if len(sigfeats) > 0:
|
||||
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
|
||||
if ref_contexts is not None:
|
||||
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
|
||||
@@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
dit_config["patch_size"] = 16
|
||||
dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
|
||||
dit_config["nerf_hidden_size"] = 64
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
@@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["time_scale"] = 1000.0
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
return dit_config
|
||||
|
||||
@@ -547,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "anima"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
|
||||
@@ -104,7 +104,7 @@ class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
|
||||
27
comfy/sd.py
27
comfy/sd.py
@@ -57,6 +57,7 @@ import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -635,14 +636,13 @@ class VAE:
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
@@ -1014,6 +1014,7 @@ class CLIPType(Enum):
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -1046,6 +1047,8 @@ class TEModel(Enum):
|
||||
QWEN3_2B = 17
|
||||
GEMMA_3_12B = 18
|
||||
JINA_CLIP_2 = 19
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1089,6 +1092,10 @@ def detect_te_model(sd):
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
elif weight.shape[0] == 4096:
|
||||
return TEModel.QWEN3_8B
|
||||
elif weight.shape[0] == 1024:
|
||||
return TEModel.QWEN3_06B
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
@@ -1214,14 +1221,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.QWEN3_8B:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
||||
@@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.anima
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -763,7 +764,7 @@ class Flux2(Flux):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
|
||||
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux2(self, device=device)
|
||||
@@ -992,6 +993,36 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class Anima(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "anima",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
@@ -1042,7 +1073,7 @@ class ZImage(Lumina2):
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
memory_usage_factor = 2.8
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@@ -1551,6 +1582,6 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
|
||||
latent_format=None, show_progress_bar=False):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
@@ -124,6 +125,9 @@ class TAEHV(nn.Module):
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
elif self.latent_channels == 128: # LTX2
|
||||
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
|
||||
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
@@ -131,41 +135,52 @@ class TAEHV(nn.Module):
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
|
||||
self.frames_to_trim = self.t_upscale - 1
|
||||
self._show_progress_bar = show_progress_bar
|
||||
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
if self.patch_size > 1:
|
||||
B, T, C, H, W = x.shape
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
||||
if x.shape[1] % self.t_downscale != 0:
|
||||
# pad at end to multiple of t_downscale
|
||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
|
||||
61
comfy/text_encoders/anima.py
Normal file
61
comfy/text_encoders/anima.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from transformers import Qwen2Tokenizer, T5TokenizerFast
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
class AnimaTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.t5xxl.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class AnimaTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out = super().encode_token_weights(token_weight_pairs)
|
||||
out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int)
|
||||
out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0])))
|
||||
return out
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class AnimaTEModel_(AnimaTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return AnimaTEModel_
|
||||
@@ -3,7 +3,7 @@ import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.llama
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast, LlamaTokenizerFast
|
||||
from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
@@ -172,3 +172,60 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
model_options["num_layers"] = 30
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Flux2TEModel_
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class KleinTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
|
||||
if name == "qwen3_4b":
|
||||
tokenizer = Qwen3Tokenizer
|
||||
elif name == "qwen3_8b":
|
||||
tokenizer = Qwen3Tokenizer8B
|
||||
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class KleinTokenizer8B(KleinTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name)
|
||||
|
||||
class Qwen3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_8BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"):
|
||||
if model_type == "qwen3_4b":
|
||||
model = Qwen3_4BModel
|
||||
elif model_type == "qwen3_8b":
|
||||
model = Qwen3_8BModel
|
||||
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
|
||||
return Flux2TEModel_
|
||||
|
||||
@@ -77,6 +77,28 @@ class Qwen25_3BConfig:
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 3072
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
@@ -99,6 +121,28 @@ class Qwen3_4BConfig:
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3_8BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 12288
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
@@ -619,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@@ -628,6 +681,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_8BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -119,7 +119,17 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
missing_all = []
|
||||
unexpected_all = []
|
||||
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False)
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
return (missing_all, unexpected_all)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
|
||||
@@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return OvisTEModel_
|
||||
|
||||
@@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return ZImageTEModel_
|
||||
|
||||
@@ -30,6 +30,7 @@ from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@@ -610,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
"ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr
|
||||
"ff.linear_in.bias": "img_mlp.0.bias",
|
||||
"ff.linear_out.weight": "img_mlp.2.weight",
|
||||
"ff.linear_out.bias": "img_mlp.2.bias",
|
||||
"ff_context.linear_in.weight": "txt_mlp.0.weight",
|
||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
@@ -638,6 +647,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@@ -928,7 +939,9 @@ def bislerp(samples, width, height):
|
||||
return result.to(orig_dtype)
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
#the below API is strict and expects grayscale to be squeezed
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
@@ -1097,6 +1110,10 @@ def set_progress_bar_global_hook(function):
|
||||
global PROGRESS_BAR_HOOK
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
# Throttle settings for progress bar updates to reduce WebSocket flooding
|
||||
PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates
|
||||
PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total, node_id=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
@@ -1104,6 +1121,8 @@ class ProgressBar:
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.node_id = node_id
|
||||
self._last_update_time = 0.0
|
||||
self._last_sent_value = -1
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
@@ -1112,7 +1131,29 @@ class ProgressBar:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
current_time = time.perf_counter()
|
||||
is_first = (self._last_sent_value < 0)
|
||||
is_final = (value >= self.total)
|
||||
has_preview = (preview is not None)
|
||||
|
||||
# Always send immediately for previews, first update, or final update
|
||||
if has_preview or is_first or is_final:
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
self._last_update_time = current_time
|
||||
self._last_sent_value = value
|
||||
return
|
||||
|
||||
# Apply throttling for regular progress updates
|
||||
if self.total > 0:
|
||||
percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100
|
||||
else:
|
||||
percent_changed = 100
|
||||
time_elapsed = current_time - self._last_update_time
|
||||
|
||||
if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT:
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
self._last_update_time = current_time
|
||||
self._last_sent_value = value
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
@@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from . import _node_replace_public as node_replace
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
@@ -130,4 +131,5 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@@ -374,7 +374,7 @@ class VideoFromComponents(VideoInput):
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
@@ -153,7 +153,7 @@ class Input(_IO_V3):
|
||||
'''
|
||||
Base class for a V3 Input.
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__()
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
@@ -162,6 +162,7 @@ class Input(_IO_V3):
|
||||
self.lazy = lazy
|
||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||
self.rawLink = raw_link
|
||||
self.advanced = advanced
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
@@ -170,6 +171,7 @@ class Input(_IO_V3):
|
||||
"tooltip": self.tooltip,
|
||||
"lazy": self.lazy,
|
||||
"rawLink": self.rawLink,
|
||||
"advanced": self.advanced,
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type(self):
|
||||
@@ -184,8 +186,8 @@ class WidgetInput(Input):
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: Any=None,
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced)
|
||||
self.default = default
|
||||
self.socketless = socketless
|
||||
self.widget_type = widget_type
|
||||
@@ -242,8 +244,8 @@ class Boolean(ComfyTypeIO):
|
||||
'''Boolean input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: bool=None, label_on: str=None, label_off: str=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.label_on = label_on
|
||||
self.label_off = label_off
|
||||
self.default: bool
|
||||
@@ -262,8 +264,8 @@ class Int(ComfyTypeIO):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@@ -288,8 +290,8 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@@ -314,8 +316,8 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
@@ -350,12 +352,13 @@ class Combo(ComfyTypeIO):
|
||||
socketless: bool=None,
|
||||
extra_dict=None,
|
||||
raw_link: bool=None,
|
||||
advanced: bool=None,
|
||||
):
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced)
|
||||
self.multiselect = False
|
||||
self.options = options
|
||||
self.control_after_generate = control_after_generate
|
||||
@@ -387,8 +390,8 @@ class MultiCombo(ComfyTypeI):
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
||||
self.multiselect = True
|
||||
self.placeholder = placeholder
|
||||
self.chip = chip
|
||||
@@ -421,9 +424,9 @@ class Webcam(ComfyTypeIO):
|
||||
Type = str
|
||||
def __init__(
|
||||
self, id: str, display_name: str=None, optional=False,
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced)
|
||||
|
||||
|
||||
@comfytype(io_type="MASK")
|
||||
@@ -751,7 +754,7 @@ class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
class ModelPatch(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
@@ -776,7 +779,7 @@ class MultiType:
|
||||
'''
|
||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||
'''
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
# if id is an Input, then use that Input with overridden values
|
||||
self.input_override = None
|
||||
if isinstance(id, Input):
|
||||
@@ -789,7 +792,7 @@ class MultiType:
|
||||
# if is a widget input, make sure widget_type is set appropriately
|
||||
if isinstance(self.input_override, WidgetInput):
|
||||
self.input_override.widget_type = self.input_override.get_io_type()
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced)
|
||||
self._io_types = types
|
||||
|
||||
@property
|
||||
@@ -843,8 +846,8 @@ class MatchType(ComfyTypeIO):
|
||||
|
||||
class Input(Input):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
@@ -997,20 +1000,38 @@ class Autogrow(ComfyTypeI):
|
||||
names = [f"{prefix}{i}" for i in range(max)]
|
||||
# need to create a new input based on the contents of input
|
||||
template_input = None
|
||||
for _, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input
|
||||
template_required = True
|
||||
for _input_type, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input; if not required, min can be ignored
|
||||
if len(dict_input) == 0:
|
||||
continue
|
||||
template_input = list(dict_input.values())[0]
|
||||
template_required = _input_type == "required"
|
||||
break
|
||||
if template_input is None:
|
||||
raise Exception("template_input could not be determined from required or optional; this should never happen.")
|
||||
new_dict = {}
|
||||
new_dict_added_to = False
|
||||
# first, add possible inputs into out_dict
|
||||
for i, name in enumerate(names):
|
||||
expected_id = finalize_prefix(curr_prefix, name)
|
||||
# required
|
||||
if i < min and template_required:
|
||||
out_dict["required"][expected_id] = template_input
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
out_dict["optional"][expected_id] = template_input
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
if expected_id in live_inputs:
|
||||
# required
|
||||
if i < min:
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
# NOTE: prefix gets added in parse_class_inputs
|
||||
type_dict[name] = template_input
|
||||
new_dict_added_to = True
|
||||
# account for the edge case that all inputs are optional and no values are received
|
||||
if not new_dict_added_to:
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT
|
||||
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
@@ -1119,8 +1140,8 @@ class ImageCompare(ComfyTypeI):
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True):
|
||||
super().__init__(id, display_name, optional, tooltip, None, None, socketless)
|
||||
socketless: bool=True, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
@@ -1148,6 +1169,8 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||
dynamic_paths: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@@ -1225,6 +1248,8 @@ class NodeInfoV1:
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
|
||||
@dataclass
|
||||
class NodeInfoV3:
|
||||
@@ -1234,11 +1259,77 @@ class NodeInfoV3:
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
python_module: Any = None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceBadgeDepends:
|
||||
widgets: list[str] = field(default_factory=list)
|
||||
inputs: list[str] = field(default_factory=list)
|
||||
input_groups: list[str] = field(default_factory=list)
|
||||
|
||||
def validate(self) -> None:
|
||||
if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets):
|
||||
raise ValueError("PriceBadgeDepends.widgets must be a list[str].")
|
||||
if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs):
|
||||
raise ValueError("PriceBadgeDepends.inputs must be a list[str].")
|
||||
if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups):
|
||||
raise ValueError("PriceBadgeDepends.input_groups must be a list[str].")
|
||||
|
||||
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
|
||||
# Build lookup: widget_id -> io_type
|
||||
input_types: dict[str, str] = {}
|
||||
for inp in schema_inputs:
|
||||
all_inputs = inp.get_all()
|
||||
input_types[inp.id] = inp.get_io_type() # First input is always the parent itself
|
||||
for nested_inp in all_inputs[1:]:
|
||||
# For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID
|
||||
# to match frontend naming convention (e.g., "should_texture.enable_pbr")
|
||||
prefixed_id = f"{inp.id}.{nested_inp.id}"
|
||||
input_types[prefixed_id] = nested_inp.get_io_type()
|
||||
|
||||
# Enrich widgets with type information, raising error for unknown widgets
|
||||
widgets_data: list[dict[str, str]] = []
|
||||
for w in self.widgets:
|
||||
if w not in input_types:
|
||||
raise ValueError(
|
||||
f"PriceBadge depends_on.widgets references unknown widget '{w}'. "
|
||||
f"Available widgets: {list(input_types.keys())}"
|
||||
)
|
||||
widgets_data.append({"name": w, "type": input_types[w]})
|
||||
|
||||
return {
|
||||
"widgets": widgets_data,
|
||||
"inputs": self.inputs,
|
||||
"input_groups": self.input_groups,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceBadge:
|
||||
expr: str
|
||||
depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends)
|
||||
engine: str = field(default="jsonata")
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.engine != "jsonata":
|
||||
raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.")
|
||||
if not isinstance(self.expr, str) or not self.expr.strip():
|
||||
raise ValueError("PriceBadge.expr must be a non-empty string.")
|
||||
self.depends_on.validate()
|
||||
|
||||
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
|
||||
return {
|
||||
"engine": self.engine,
|
||||
"depends_on": self.depends_on.as_dict(schema_inputs),
|
||||
"expr": self.expr,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1256,6 +1347,8 @@ class Schema:
|
||||
hidden: list[Hidden] = field(default_factory=list)
|
||||
description: str=""
|
||||
"""Node description, shown as a tooltip when hovering over the node."""
|
||||
search_aliases: list[str] = field(default_factory=list)
|
||||
"""Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming."""
|
||||
is_input_list: bool = False
|
||||
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
|
||||
|
||||
@@ -1284,6 +1377,8 @@ class Schema:
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
is_api_node: bool=False
|
||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||
price_badge: PriceBadge | None = None
|
||||
"""Optional client-evaluated pricing badge declaration for this node."""
|
||||
not_idempotent: bool=False
|
||||
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
|
||||
enable_expand: bool=False
|
||||
@@ -1314,6 +1409,8 @@ class Schema:
|
||||
input.validate()
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
if self.price_badge is not None:
|
||||
self.price_badge.validate()
|
||||
|
||||
def finalize(self):
|
||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||
@@ -1387,7 +1484,9 @@ class Schema:
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
search_aliases=self.search_aliases if self.search_aliases else None,
|
||||
)
|
||||
return info
|
||||
|
||||
@@ -1419,7 +1518,8 @@ class Schema:
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
)
|
||||
return info
|
||||
|
||||
@@ -1428,6 +1528,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"required": {},
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@@ -1437,8 +1538,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
out_dict["hidden"] = hidden
|
||||
v3_data = {}
|
||||
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||
if dynamic_paths is not None:
|
||||
if dynamic_paths is not None and len(dynamic_paths) > 0:
|
||||
v3_data["dynamic_paths"] = dynamic_paths
|
||||
# this list is used for autogrow, in the case all inputs are optional and no values are passed
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@@ -1475,11 +1580,16 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
|
||||
result = {}
|
||||
|
||||
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||
@@ -1493,6 +1603,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
|
||||
if is_last:
|
||||
value = values.pop(key, None)
|
||||
if value is None:
|
||||
# see if a default value was provided for this key
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@@ -1923,6 +2038,7 @@ __all__ = [
|
||||
"ControlNet",
|
||||
"Vae",
|
||||
"Model",
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"AudioEncoder",
|
||||
@@ -1971,4 +2087,6 @@ __all__ = [
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
]
|
||||
|
||||
109
comfy_api/latest/_node_replace.py
Normal file
109
comfy_api/latest/_node_replace.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import app.node_replace_manager
|
||||
|
||||
def register_node_replacement(node_replace: NodeReplace):
|
||||
"""
|
||||
Register node replacement.
|
||||
"""
|
||||
app.node_replace_manager.register_node_replacement(node_replace)
|
||||
|
||||
|
||||
class NodeReplace:
|
||||
"""
|
||||
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||
|
||||
Also supports assigning specific values to the input widgets of the new node.
|
||||
"""
|
||||
def __init__(self,
|
||||
new_node_id: str,
|
||||
old_node_id: str,
|
||||
old_widget_ids: list[str] | None=None,
|
||||
input_mapping: list[InputMap] | None=None,
|
||||
output_mapping: list[OutputMap] | None=None,
|
||||
):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
"""
|
||||
Create serializable representation of the node replacement.
|
||||
"""
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None,
|
||||
"output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class InputMap:
|
||||
"""
|
||||
Map inputs of node replacement.
|
||||
|
||||
Use InputMap.OldId or InputMap.SetValue for mapping purposes.
|
||||
"""
|
||||
class _Assign:
|
||||
def __init__(self, assign_type: str):
|
||||
self.assign_type = assign_type
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"assign_type": self.assign_type,
|
||||
}
|
||||
|
||||
class OldId(_Assign):
|
||||
"""
|
||||
Connect the input of the old node with given id to new node when replacing.
|
||||
"""
|
||||
def __init__(self, old_id: str):
|
||||
super().__init__("old_id")
|
||||
self.old_id = old_id
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"old_id": self.old_id,
|
||||
}
|
||||
|
||||
class SetValue(_Assign):
|
||||
"""
|
||||
Use the given value for the input of the new node when replacing; assumes input is a widget.
|
||||
"""
|
||||
def __init__(self, value: Any):
|
||||
super().__init__("set_value")
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"value": self.value,
|
||||
}
|
||||
|
||||
def __init__(self, new_id: str, assign: OldId | SetValue):
|
||||
self.new_id = new_id
|
||||
self.assign = assign
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_id": self.new_id,
|
||||
"assign": self.assign.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
class OutputMap:
|
||||
"""
|
||||
Map outputs of node replacement via indexes, as that's how outputs are stored.
|
||||
"""
|
||||
def __init__(self, new_idx: int, old_idx: int):
|
||||
self.new_idx = new_idx
|
||||
self.old_idx = old_idx
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_idx": self.new_idx,
|
||||
"old_idx": self.old_idx,
|
||||
}
|
||||
1
comfy_api/latest/_node_replace_public.py
Normal file
1
comfy_api/latest/_node_replace_public.py
Normal file
@@ -0,0 +1 @@
|
||||
from ._node_replace import * # noqa: F403
|
||||
@@ -6,7 +6,7 @@ from comfy_api.latest import (
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
@@ -46,4 +46,5 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# ComfyUI API Nodes
|
||||
|
||||
## Introduction
|
||||
|
||||
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
|
||||
## Development
|
||||
|
||||
While developing, you should be testing against the Staging environment. To test against staging:
|
||||
|
||||
**Install ComfyUI_frontend**
|
||||
|
||||
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
|
||||
|
||||
> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
|
||||
|
||||
```bash
|
||||
python run main.py --comfy-api-base https://stagingapi.comfy.org
|
||||
```
|
||||
|
||||
To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging.
|
||||
|
||||
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
||||
|
||||
### Redocly Instructions
|
||||
|
||||
**Tip**
|
||||
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
|
||||
|
||||
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
||||
|
||||
```bash
|
||||
# Download the OpenAPI file from staging server.
|
||||
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
||||
|
||||
# Filter out unneeded API definitions.
|
||||
npm install -g @redocly/cli
|
||||
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
|
||||
|
||||
# Generate the pydantic datamodels for validation.
|
||||
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
```
|
||||
|
||||
|
||||
# Merging to Master
|
||||
|
||||
Before merging to comfyanonymous/ComfyUI master, follow these steps:
|
||||
|
||||
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
|
||||
1. Make sure the ComfyUI API is deployed to prod with your changes.
|
||||
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
|
||||
|
||||
```bash
|
||||
# Download the OpenAPI file from prod server.
|
||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||
|
||||
# Filter out unneeded API definitions.
|
||||
npm install -g @redocly/cli
|
||||
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||
|
||||
# Generate the pydantic datamodels for validation.
|
||||
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
```
|
||||
61
comfy_api_nodes/apis/bria.py
Normal file
61
comfy_api_nodes/apis/bria.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InputModerationSettings(TypedDict):
|
||||
prompt_content_moderation: bool
|
||||
visual_input_moderation: bool
|
||||
visual_output_moderation: bool
|
||||
|
||||
|
||||
class BriaEditImageRequest(BaseModel):
|
||||
instruction: str | None = Field(...)
|
||||
structured_instruction: str | None = Field(
|
||||
...,
|
||||
description="Use this instead of instruction for precise, programmatic control.",
|
||||
)
|
||||
images: list[str] = Field(
|
||||
...,
|
||||
description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.",
|
||||
)
|
||||
mask: str | None = Field(
|
||||
None,
|
||||
description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
|
||||
"If omitted, the edit applies to the entire image. "
|
||||
"The input image and the the input mask must be of the same size.",
|
||||
)
|
||||
negative_prompt: str | None = Field(None)
|
||||
guidance_scale: float = Field(...)
|
||||
model_version: str = Field(...)
|
||||
steps_num: int = Field(...)
|
||||
seed: int = Field(...)
|
||||
ip_signal: bool = Field(
|
||||
False,
|
||||
description="If true, returns a warning for potential IP content in the instruction.",
|
||||
)
|
||||
prompt_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on instruction moderation failure."
|
||||
)
|
||||
visual_input_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on images or mask moderation failure."
|
||||
)
|
||||
visual_output_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on visual output moderation failure."
|
||||
)
|
||||
|
||||
|
||||
class BriaStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
status_url: str = Field(...)
|
||||
warning: str | None = Field(None)
|
||||
|
||||
|
||||
class BriaResult(BaseModel):
|
||||
structured_prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaResult | None = Field(None)
|
||||
@@ -65,11 +65,13 @@ class TaskImageContent(BaseModel):
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
generate_audio: bool | None = Field(...)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
generate_audio: bool | None = Field(...)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
@@ -141,4 +143,9 @@ VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
"seedance-1-5-pro-251215": {
|
||||
"480p": 80,
|
||||
"720p": 100,
|
||||
"1080p": 150,
|
||||
},
|
||||
}
|
||||
292
comfy_api_nodes/apis/ideogram.py
Normal file
292
comfy_api_nodes/apis/ideogram.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel, StrictBytes
|
||||
|
||||
|
||||
class IdeogramColorPalette1(BaseModel):
|
||||
name: str = Field(..., description='Name of the preset color palette')
|
||||
|
||||
|
||||
class Member(BaseModel):
|
||||
color: Optional[str] = Field(
|
||||
None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$'
|
||||
)
|
||||
weight: Optional[float] = Field(
|
||||
None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0
|
||||
)
|
||||
|
||||
|
||||
class IdeogramColorPalette2(BaseModel):
|
||||
members: List[Member] = Field(
|
||||
..., description='Array of color definitions with optional weights'
|
||||
)
|
||||
|
||||
|
||||
class IdeogramColorPalette(
|
||||
RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]]
|
||||
):
|
||||
root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field(
|
||||
...,
|
||||
description='A color palette specification that can either use a preset name or explicit color definitions with weights',
|
||||
)
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
|
||||
)
|
||||
color_palette: Optional[Dict[str, Any]] = Field(
|
||||
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
|
||||
)
|
||||
magic_prompt_option: Optional[str] = Field(
|
||||
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
|
||||
)
|
||||
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
|
||||
)
|
||||
num_images: Optional[int] = Field(
|
||||
1,
|
||||
description='Optional. Number of images to generate (1-8). Defaults to 1.',
|
||||
ge=1,
|
||||
le=8,
|
||||
)
|
||||
prompt: str = Field(
|
||||
..., description='Required. The prompt to use to generate the image.'
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
|
||||
)
|
||||
seed: Optional[int] = Field(
|
||||
None,
|
||||
description='Optional. A number between 0 and 2147483647.',
|
||||
ge=0,
|
||||
le=2147483647,
|
||||
)
|
||||
style_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
|
||||
)
|
||||
|
||||
|
||||
class IdeogramGenerateRequest(BaseModel):
|
||||
image_request: ImageRequest = Field(
|
||||
..., description='The image generation request parameters.'
|
||||
)
|
||||
|
||||
|
||||
class Datum(BaseModel):
|
||||
is_image_safe: Optional[bool] = Field(
|
||||
None, description='Indicates whether the image is considered safe.'
|
||||
)
|
||||
prompt: Optional[str] = Field(
|
||||
None, description='The prompt used to generate this image.'
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None, description="The resolution of the generated image (e.g., '1024x1024')."
|
||||
)
|
||||
seed: Optional[int] = Field(
|
||||
None, description='The seed value used for this generation.'
|
||||
)
|
||||
style_type: Optional[str] = Field(
|
||||
None,
|
||||
description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
|
||||
)
|
||||
url: Optional[str] = Field(None, description='URL to the generated image.')
|
||||
|
||||
|
||||
class IdeogramGenerateResponse(BaseModel):
|
||||
created: Optional[datetime] = Field(
|
||||
None, description='Timestamp when the generation was created.'
|
||||
)
|
||||
data: Optional[List[Datum]] = Field(
|
||||
None, description='Array of generated image information.'
|
||||
)
|
||||
|
||||
|
||||
class StyleCode(RootModel[str]):
|
||||
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
|
||||
|
||||
|
||||
class Datum1(BaseModel):
|
||||
is_image_safe: Optional[bool] = None
|
||||
prompt: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
style_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class IdeogramV3IdeogramResponse(BaseModel):
|
||||
created: Optional[datetime] = None
|
||||
data: Optional[List[Datum1]] = None
|
||||
|
||||
|
||||
class RenderingSpeed1(str, Enum):
|
||||
TURBO = 'TURBO'
|
||||
DEFAULT = 'DEFAULT'
|
||||
QUALITY = 'QUALITY'
|
||||
|
||||
|
||||
class IdeogramV3ReframeRequest(BaseModel):
|
||||
color_palette: Optional[Dict[str, Any]] = None
|
||||
image: Optional[StrictBytes] = None
|
||||
num_images: Optional[int] = Field(None, ge=1, le=8)
|
||||
rendering_speed: Optional[RenderingSpeed1] = None
|
||||
resolution: str
|
||||
seed: Optional[int] = Field(None, ge=0, le=2147483647)
|
||||
style_codes: Optional[List[str]] = None
|
||||
style_reference_images: Optional[List[StrictBytes]] = None
|
||||
|
||||
|
||||
class MagicPrompt(str, Enum):
|
||||
AUTO = 'AUTO'
|
||||
ON = 'ON'
|
||||
OFF = 'OFF'
|
||||
|
||||
|
||||
class StyleType(str, Enum):
|
||||
AUTO = 'AUTO'
|
||||
GENERAL = 'GENERAL'
|
||||
REALISTIC = 'REALISTIC'
|
||||
DESIGN = 'DESIGN'
|
||||
|
||||
|
||||
class IdeogramV3RemixRequest(BaseModel):
|
||||
aspect_ratio: Optional[str] = None
|
||||
color_palette: Optional[Dict[str, Any]] = None
|
||||
image: Optional[StrictBytes] = None
|
||||
image_weight: Optional[int] = Field(50, ge=1, le=100)
|
||||
magic_prompt: Optional[MagicPrompt] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
num_images: Optional[int] = Field(None, ge=1, le=8)
|
||||
prompt: str
|
||||
rendering_speed: Optional[RenderingSpeed1] = None
|
||||
resolution: Optional[str] = None
|
||||
seed: Optional[int] = Field(None, ge=0, le=2147483647)
|
||||
style_codes: Optional[List[str]] = None
|
||||
style_reference_images: Optional[List[StrictBytes]] = None
|
||||
style_type: Optional[StyleType] = None
|
||||
|
||||
|
||||
class IdeogramV3ReplaceBackgroundRequest(BaseModel):
|
||||
color_palette: Optional[Dict[str, Any]] = None
|
||||
image: Optional[StrictBytes] = None
|
||||
magic_prompt: Optional[MagicPrompt] = None
|
||||
num_images: Optional[int] = Field(None, ge=1, le=8)
|
||||
prompt: str
|
||||
rendering_speed: Optional[RenderingSpeed1] = None
|
||||
seed: Optional[int] = Field(None, ge=0, le=2147483647)
|
||||
style_codes: Optional[List[str]] = None
|
||||
style_reference_images: Optional[List[StrictBytes]] = None
|
||||
|
||||
|
||||
class ColorPalette(BaseModel):
|
||||
name: str = Field(..., description='Name of the color palette', examples=['PASTEL'])
|
||||
|
||||
|
||||
class MagicPrompt2(str, Enum):
|
||||
ON = 'ON'
|
||||
OFF = 'OFF'
|
||||
|
||||
|
||||
class StyleType1(str, Enum):
|
||||
AUTO = 'AUTO'
|
||||
GENERAL = 'GENERAL'
|
||||
REALISTIC = 'REALISTIC'
|
||||
DESIGN = 'DESIGN'
|
||||
FICTION = 'FICTION'
|
||||
|
||||
|
||||
class RenderingSpeed(str, Enum):
|
||||
DEFAULT = 'DEFAULT'
|
||||
TURBO = 'TURBO'
|
||||
QUALITY = 'QUALITY'
|
||||
|
||||
|
||||
class IdeogramV3EditRequest(BaseModel):
|
||||
color_palette: Optional[IdeogramColorPalette] = None
|
||||
image: Optional[StrictBytes] = Field(
|
||||
None,
|
||||
description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.',
|
||||
)
|
||||
magic_prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Determine if MagicPrompt should be used in generating the request or not.',
|
||||
)
|
||||
mask: Optional[StrictBytes] = Field(
|
||||
None,
|
||||
description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.',
|
||||
)
|
||||
num_images: Optional[int] = Field(
|
||||
None, description='The number of images to generate.'
|
||||
)
|
||||
prompt: str = Field(
|
||||
..., description='The prompt used to describe the edited result.'
|
||||
)
|
||||
rendering_speed: RenderingSpeed
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed. Set for reproducible generation.'
|
||||
)
|
||||
style_codes: Optional[List[StyleCode]] = Field(
|
||||
None,
|
||||
description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.',
|
||||
)
|
||||
style_reference_images: Optional[List[StrictBytes]] = Field(
|
||||
None,
|
||||
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class IdeogramV3Request(BaseModel):
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
None, description='Aspect ratio in format WxH', examples=['1x3']
|
||||
)
|
||||
color_palette: Optional[ColorPalette] = None
|
||||
magic_prompt: Optional[MagicPrompt2] = Field(
|
||||
None, description='Whether to enable magic prompt enhancement'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Text prompt specifying what to avoid in the generation'
|
||||
)
|
||||
num_images: Optional[int] = Field(
|
||||
None, description='Number of images to generate', ge=1
|
||||
)
|
||||
prompt: str = Field(..., description='The text prompt for image generation')
|
||||
rendering_speed: RenderingSpeed
|
||||
resolution: Optional[str] = Field(
|
||||
None, description='Image resolution in format WxH', examples=['1280x800']
|
||||
)
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Seed value for reproducible generation'
|
||||
)
|
||||
style_codes: Optional[List[StyleCode]] = Field(
|
||||
None, description='Array of style codes in hexadecimal format'
|
||||
)
|
||||
style_reference_images: Optional[List[str]] = Field(
|
||||
None, description='Array of reference image URLs or identifiers'
|
||||
)
|
||||
style_type: Optional[StyleType1] = Field(
|
||||
None, description='The type of style to apply'
|
||||
)
|
||||
character_reference_images: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
character_reference_images_mask: Optional[List[str]] = Field(
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
160
comfy_api_nodes/apis/meshy.py
Normal file
160
comfy_api_nodes/apis/meshy.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
class InputShouldRemesh(TypedDict):
|
||||
should_remesh: str
|
||||
topology: str
|
||||
target_polycount: int
|
||||
|
||||
|
||||
class InputShouldTexture(TypedDict):
|
||||
should_texture: str
|
||||
enable_pbr: bool
|
||||
texture_prompt: str
|
||||
texture_image: Input.Image | None
|
||||
|
||||
|
||||
class MeshyTaskResponse(BaseModel):
|
||||
result: str = Field(...)
|
||||
|
||||
|
||||
class MeshyTextToModelRequest(BaseModel):
|
||||
mode: str = Field("preview")
|
||||
prompt: str = Field(..., max_length=600)
|
||||
art_style: str = Field(..., description="'realistic' or 'sculpture'")
|
||||
ai_model: str = Field(...)
|
||||
topology: str | None = Field(..., description="'quad' or 'triangle'")
|
||||
target_polycount: int | None = Field(..., ge=100, le=300000)
|
||||
should_remesh: bool = Field(
|
||||
True,
|
||||
description="False returns the original mesh, ignoring topology and polycount.",
|
||||
)
|
||||
symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'")
|
||||
pose_mode: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
moderation: bool = Field(False)
|
||||
|
||||
|
||||
class MeshyRefineTask(BaseModel):
|
||||
mode: str = Field("refine")
|
||||
preview_task_id: str = Field(...)
|
||||
enable_pbr: bool | None = Field(...)
|
||||
texture_prompt: str | None = Field(...)
|
||||
texture_image_url: str | None = Field(...)
|
||||
ai_model: str = Field(...)
|
||||
moderation: bool = Field(False)
|
||||
|
||||
|
||||
class MeshyImageToModelRequest(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
ai_model: str = Field(...)
|
||||
topology: str | None = Field(..., description="'quad' or 'triangle'")
|
||||
target_polycount: int | None = Field(..., ge=100, le=300000)
|
||||
symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'")
|
||||
should_remesh: bool = Field(
|
||||
True,
|
||||
description="False returns the original mesh, ignoring topology and polycount.",
|
||||
)
|
||||
should_texture: bool = Field(...)
|
||||
enable_pbr: bool | None = Field(...)
|
||||
pose_mode: str = Field(...)
|
||||
texture_prompt: str | None = Field(None, max_length=600)
|
||||
texture_image_url: str | None = Field(None)
|
||||
seed: int = Field(...)
|
||||
moderation: bool = Field(False)
|
||||
|
||||
|
||||
class MeshyMultiImageToModelRequest(BaseModel):
|
||||
image_urls: list[str] = Field(...)
|
||||
ai_model: str = Field(...)
|
||||
topology: str | None = Field(..., description="'quad' or 'triangle'")
|
||||
target_polycount: int | None = Field(..., ge=100, le=300000)
|
||||
symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'")
|
||||
should_remesh: bool = Field(
|
||||
True,
|
||||
description="False returns the original mesh, ignoring topology and polycount.",
|
||||
)
|
||||
should_texture: bool = Field(...)
|
||||
enable_pbr: bool | None = Field(...)
|
||||
pose_mode: str = Field(...)
|
||||
texture_prompt: str | None = Field(None, max_length=600)
|
||||
texture_image_url: str | None = Field(None)
|
||||
seed: int = Field(...)
|
||||
moderation: bool = Field(False)
|
||||
|
||||
|
||||
class MeshyRiggingRequest(BaseModel):
|
||||
input_task_id: str = Field(...)
|
||||
height_meters: float = Field(...)
|
||||
texture_image_url: str | None = Field(...)
|
||||
|
||||
|
||||
class MeshyAnimationRequest(BaseModel):
|
||||
rig_task_id: str = Field(...)
|
||||
action_id: int = Field(...)
|
||||
|
||||
|
||||
class MeshyTextureRequest(BaseModel):
|
||||
input_task_id: str = Field(...)
|
||||
ai_model: str = Field(...)
|
||||
enable_original_uv: bool = Field(...)
|
||||
enable_pbr: bool = Field(...)
|
||||
text_style_prompt: str | None = Field(...)
|
||||
image_style_url: str | None = Field(...)
|
||||
|
||||
|
||||
class MeshyModelsUrls(BaseModel):
|
||||
glb: str = Field("")
|
||||
|
||||
|
||||
class MeshyRiggedModelsUrls(BaseModel):
|
||||
rigged_character_glb_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyAnimatedModelsUrls(BaseModel):
|
||||
animation_glb_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyResultTextureUrls(BaseModel):
|
||||
base_color: str = Field(...)
|
||||
metallic: str | None = Field(None)
|
||||
normal: str | None = Field(None)
|
||||
roughness: str | None = Field(None)
|
||||
|
||||
|
||||
class MeshyTaskError(BaseModel):
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class MeshyModelResult(BaseModel):
|
||||
id: str = Field(...)
|
||||
type: str = Field(...)
|
||||
model_urls: MeshyModelsUrls = Field(MeshyModelsUrls())
|
||||
thumbnail_url: str = Field(...)
|
||||
video_url: str | None = Field(None)
|
||||
status: str = Field(...)
|
||||
progress: int = Field(0)
|
||||
texture_urls: list[MeshyResultTextureUrls] | None = Field([])
|
||||
task_error: MeshyTaskError | None = Field(None)
|
||||
|
||||
|
||||
class MeshyRiggedResult(BaseModel):
|
||||
id: str = Field(...)
|
||||
type: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: int = Field(0)
|
||||
result: MeshyRiggedModelsUrls = Field(MeshyRiggedModelsUrls())
|
||||
task_error: MeshyTaskError | None = Field(None)
|
||||
|
||||
|
||||
class MeshyAnimationResult(BaseModel):
|
||||
id: str = Field(...)
|
||||
type: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: int = Field(0)
|
||||
result: MeshyAnimatedModelsUrls = Field(MeshyAnimatedModelsUrls())
|
||||
task_error: MeshyTaskError | None = Field(None)
|
||||
152
comfy_api_nodes/apis/moonvalley.py
Normal file
152
comfy_api_nodes/apis/moonvalley.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, StrictBytes
|
||||
|
||||
|
||||
class MoonvalleyPromptResponse(BaseModel):
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
frame_conditioning: Optional[Dict[str, Any]] = None
|
||||
id: Optional[str] = None
|
||||
inference_params: Optional[Dict[str, Any]] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
model_params: Optional[Dict[str, Any]] = None
|
||||
output_url: Optional[str] = None
|
||||
prompt_text: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
75, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
fps: Optional[int] = Field(
|
||||
24, description='Frames per second of the generated video'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
10, description='Guidance scale for generation control'
|
||||
)
|
||||
height: Optional[int] = Field(
|
||||
1080, description='Height of the generated video in pixels'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
0, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
width: Optional[int] = Field(
|
||||
1920, description='Width of the generated video in pixels'
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileRequest(BaseModel):
|
||||
file: Optional[StrictBytes] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileResponse(BaseModel):
|
||||
access_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
36, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
15, description='Guidance scale for generation control'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
24, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
|
||||
|
||||
class ControlType(str, Enum):
|
||||
motion_control = 'motion_control'
|
||||
pose_control = 'pose_control'
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoRequest(BaseModel):
|
||||
control_type: ControlType = Field(
|
||||
..., description='Supported types for video control'
|
||||
)
|
||||
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
|
||||
prompt_text: str = Field(..., description='Describes the video to generate')
|
||||
video_url: str = Field(..., description='Url to control video')
|
||||
webhook_url: Optional[str] = Field(
|
||||
None, description='Optional webhook URL for notifications'
|
||||
)
|
||||
170
comfy_api_nodes/apis/openai.py
Normal file
170
comfy_api_nodes/apis/openai.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Datum2(BaseModel):
|
||||
b64_json: str | None = Field(None, description="Base64 encoded image data")
|
||||
revised_prompt: str | None = Field(None, description="Revised prompt")
|
||||
url: str | None = Field(None, description="URL of the image")
|
||||
|
||||
|
||||
class InputTokensDetails(BaseModel):
|
||||
image_tokens: int | None = Field(None)
|
||||
text_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int | None = Field(None)
|
||||
input_tokens_details: InputTokensDetails | None = Field(None)
|
||||
output_tokens: int | None = Field(None)
|
||||
total_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class OpenAIImageGenerationResponse(BaseModel):
|
||||
data: list[Datum2] | None = Field(None)
|
||||
usage: Usage | None = Field(None)
|
||||
|
||||
|
||||
class OpenAIImageEditRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str = Field(...)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(None, description="The number of images to generate")
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
size: str | None = Field(None, description="Size of the output image")
|
||||
|
||||
|
||||
class OpenAIImageGenerationRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str | None = Field(None)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
description="The number of images to generate.",
|
||||
)
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="The quality of the generated image")
|
||||
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
|
||||
|
||||
|
||||
class ModelResponseProperties(BaseModel):
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
|
||||
top_p: float | None = Field(
|
||||
1,
|
||||
description="Controls diversity of the response via nucleus sampling",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
|
||||
|
||||
|
||||
class ResponseProperties(BaseModel):
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
previous_response_id: str | None = Field(None)
|
||||
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
|
||||
|
||||
|
||||
class ResponseError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class OutputTokensDetails(BaseModel):
|
||||
reasoning_tokens: int = Field(..., description="The number of reasoning tokens.")
|
||||
|
||||
|
||||
class CachedTokensDetails(BaseModel):
|
||||
cached_tokens: int = Field(
|
||||
...,
|
||||
description="The number of tokens that were retrieved from the cache.",
|
||||
)
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel):
|
||||
input_tokens: int = Field(..., description="The number of input tokens.")
|
||||
input_tokens_details: CachedTokensDetails = Field(...)
|
||||
output_tokens: int = Field(..., description="The number of output tokens.")
|
||||
output_tokens_details: OutputTokensDetails = Field(...)
|
||||
total_tokens: int = Field(..., description="The total number of tokens used.")
|
||||
|
||||
|
||||
class InputTextContent(BaseModel):
|
||||
text: str = Field(..., description="The text input to the model.")
|
||||
type: str = Field("input_text")
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
type: str = Field(..., description="The type of output content")
|
||||
text: str | None = Field(None, description="The text content")
|
||||
data: str | None = Field(None, description="Base64-encoded audio data")
|
||||
transcript: str | None = Field(None, description="Transcript of the audio")
|
||||
|
||||
|
||||
class OutputMessage(BaseModel):
|
||||
type: str = Field(..., description="The type of output item")
|
||||
content: list[OutputContent] | None = Field(None, description="The content of the message")
|
||||
role: str | None = Field(None, description="The role of the message")
|
||||
|
||||
|
||||
class OpenAIResponse(ModelResponseProperties, ResponseProperties):
|
||||
created_at: float | None = Field(
|
||||
None,
|
||||
description="Unix timestamp (in seconds) of when this Response was created.",
|
||||
)
|
||||
error: ResponseError | None = Field(None)
|
||||
id: str | None = Field(None, description="Unique identifier for this Response.")
|
||||
object: str | None = Field(None, description="The object type of this resource - always set to `response`.")
|
||||
output: list[OutputMessage] | None = Field(None)
|
||||
parallel_tool_calls: bool | None = Field(True)
|
||||
status: str | None = Field(
|
||||
None,
|
||||
description="One of `completed`, `failed`, `in_progress`, or `incomplete`.",
|
||||
)
|
||||
usage: ResponseUsage | None = Field(None)
|
||||
|
||||
|
||||
class InputImageContent(BaseModel):
|
||||
detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.")
|
||||
file_id: str | None = Field(None)
|
||||
image_url: str | None = Field(None)
|
||||
type: str = Field(..., description="The type of the input item. Always `input_image`.")
|
||||
|
||||
|
||||
class InputFileContent(BaseModel):
|
||||
file_data: str | None = Field(None)
|
||||
file_id: str | None = Field(None)
|
||||
filename: str | None = Field(None, description="The name of the file to be sent to the model.")
|
||||
type: str = Field(..., description="The type of the input item. Always `input_file`.")
|
||||
|
||||
|
||||
class InputMessage(BaseModel):
|
||||
content: list[InputTextContent | InputImageContent | InputFileContent] = Field(
|
||||
...,
|
||||
description="A list of one or many input items to the model, containing different content types.",
|
||||
)
|
||||
role: str | None = Field(None)
|
||||
type: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenAICreateResponse(ModelResponseProperties, ResponseProperties):
|
||||
include: str | None = Field(None)
|
||||
input: list[InputMessage] = Field(...)
|
||||
parallel_tool_calls: bool | None = Field(
|
||||
True, description="Whether to allow the model to run tool calls in parallel."
|
||||
)
|
||||
store: bool | None = Field(
|
||||
True,
|
||||
description="Whether to store the generated model response for later retrieval via API.",
|
||||
)
|
||||
stream: bool | None = Field(False)
|
||||
usage: ResponseUsage | None = Field(None)
|
||||
@@ -1,52 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Datum2(BaseModel):
|
||||
b64_json: str | None = Field(None, description="Base64 encoded image data")
|
||||
revised_prompt: str | None = Field(None, description="Revised prompt")
|
||||
url: str | None = Field(None, description="URL of the image")
|
||||
|
||||
|
||||
class InputTokensDetails(BaseModel):
|
||||
image_tokens: int | None = None
|
||||
text_tokens: int | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
input_tokens_details: InputTokensDetails | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIImageGenerationResponse(BaseModel):
|
||||
data: list[Datum2] | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
|
||||
class OpenAIImageEditRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str = Field(...)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(None, description="The number of images to generate")
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
size: str | None = Field(None, description="Size of the output image")
|
||||
|
||||
|
||||
class OpenAIImageGenerationRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str | None = Field(None)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
description="The number of images to generate.",
|
||||
)
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="The quality of the generated image")
|
||||
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
|
||||
127
comfy_api_nodes/apis/runway.py
Normal file
127
comfy_api_nodes/apis/runway.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Union
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class RunwayAspectRatioEnum(str, Enum):
|
||||
field_1280_720 = '1280:720'
|
||||
field_720_1280 = '720:1280'
|
||||
field_1104_832 = '1104:832'
|
||||
field_832_1104 = '832:1104'
|
||||
field_960_960 = '960:960'
|
||||
field_1584_672 = '1584:672'
|
||||
field_1280_768 = '1280:768'
|
||||
field_768_1280 = '768:1280'
|
||||
|
||||
|
||||
class Position(str, Enum):
|
||||
first = 'first'
|
||||
last = 'last'
|
||||
|
||||
|
||||
class RunwayPromptImageDetailedObject(BaseModel):
|
||||
position: Position = Field(
|
||||
...,
|
||||
description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.",
|
||||
)
|
||||
uri: str = Field(
|
||||
..., description='A HTTPS URL or data URI containing an encoded image.'
|
||||
)
|
||||
|
||||
|
||||
class RunwayPromptImageObject(
|
||||
RootModel[Union[str, List[RunwayPromptImageDetailedObject]]]
|
||||
):
|
||||
root: Union[str, List[RunwayPromptImageDetailedObject]] = Field(
|
||||
...,
|
||||
description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.',
|
||||
)
|
||||
|
||||
|
||||
class RunwayModelEnum(str, Enum):
|
||||
gen4_turbo = 'gen4_turbo'
|
||||
gen3a_turbo = 'gen3a_turbo'
|
||||
|
||||
|
||||
class RunwayDurationEnum(int, Enum):
|
||||
integer_5 = 5
|
||||
integer_10 = 10
|
||||
|
||||
|
||||
class RunwayImageToVideoRequest(BaseModel):
|
||||
duration: RunwayDurationEnum
|
||||
model: RunwayModelEnum
|
||||
promptImage: RunwayPromptImageObject
|
||||
promptText: Optional[str] = Field(
|
||||
None, description='Text prompt for the generation', max_length=1000
|
||||
)
|
||||
ratio: RunwayAspectRatioEnum
|
||||
seed: int = Field(
|
||||
..., description='Random seed for generation', ge=0, le=4294967295
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
output: Optional[List[str]] = Field(None, description='Array of output video URLs')
|
||||
progress: Optional[float] = Field(
|
||||
None,
|
||||
description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.',
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: RunwayTaskStatusEnum
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
gen4_image = 'gen4_image'
|
||||
|
||||
|
||||
class ReferenceImage(BaseModel):
|
||||
uri: Optional[str] = Field(
|
||||
None, description='A HTTPS URL or data URI containing an encoded image'
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageAspectRatioEnum(str, Enum):
|
||||
field_1920_1080 = '1920:1080'
|
||||
field_1080_1920 = '1080:1920'
|
||||
field_1024_1024 = '1024:1024'
|
||||
field_1360_768 = '1360:768'
|
||||
field_1080_1080 = '1080:1080'
|
||||
field_1168_880 = '1168:880'
|
||||
field_1440_1080 = '1440:1080'
|
||||
field_1080_1440 = '1080:1440'
|
||||
field_1808_768 = '1808:768'
|
||||
field_2112_912 = '2112:912'
|
||||
|
||||
|
||||
class RunwayTextToImageRequest(BaseModel):
|
||||
model: Model4 = Field(..., description='Model to use for generation')
|
||||
promptText: str = Field(
|
||||
..., description='Text prompt for the image generation', max_length=1000
|
||||
)
|
||||
ratio: RunwayTextToImageAspectRatioEnum
|
||||
referenceImages: Optional[List[ReferenceImage]] = Field(
|
||||
None, description='Array of reference images to guide the generation'
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
@@ -41,7 +41,7 @@ class Resolution(BaseModel):
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class CreateCreateVideoRequestSource(BaseModel):
|
||||
class CreateVideoRequestSource(BaseModel):
|
||||
container: str = Field(...)
|
||||
size: int = Field(..., description="Size of the video file in bytes")
|
||||
duration: int = Field(..., description="Duration of the video file in seconds")
|
||||
@@ -89,7 +89,7 @@ class Overrides(BaseModel):
|
||||
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateCreateVideoRequestSource = Field(...)
|
||||
source: CreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
35
comfy_api_nodes/apis/wavespeed.py
Normal file
35
comfy_api_nodes/apis/wavespeed.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SeedVR2ImageRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
target_resolution: str = Field(...)
|
||||
output_format: str = Field("png")
|
||||
enable_sync_mode: bool = Field(False)
|
||||
|
||||
|
||||
class FlashVSRRequest(BaseModel):
|
||||
target_resolution: str = Field(...)
|
||||
video: str = Field(...)
|
||||
duration: float = Field(...)
|
||||
|
||||
|
||||
class TaskCreatedDataResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreatedResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskCreatedDataResponse | None = Field(None)
|
||||
|
||||
|
||||
class TaskResultDataResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
outputs: list[str] = Field([])
|
||||
|
||||
|
||||
class TaskResultResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskResultDataResponse | None = Field(None)
|
||||
@@ -1,10 +0,0 @@
|
||||
import av
|
||||
|
||||
ver = av.__version__.split(".")
|
||||
if int(ver[0]) < 14:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
if int(ver[0]) == 14 and int(ver[1]) < 2:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
@@ -1,116 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, InputTypeOptions
|
||||
|
||||
NodeInput = tuple[IO, InputTypeOptions]
|
||||
|
||||
|
||||
def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
|
||||
config = {}
|
||||
if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
|
||||
config["default"] = field_info.default
|
||||
if hasattr(field_info, "description") and field_info.description is not None:
|
||||
config["tooltip"] = field_info.description
|
||||
return config
|
||||
|
||||
|
||||
def _get_number_constraints_config(field_info: FieldInfo) -> dict:
|
||||
config = {}
|
||||
if hasattr(field_info, "metadata"):
|
||||
metadata = field_info.metadata
|
||||
for constraint in metadata:
|
||||
if hasattr(constraint, "ge"):
|
||||
config["min"] = constraint.ge
|
||||
if hasattr(constraint, "le"):
|
||||
config["max"] = constraint.le
|
||||
if hasattr(constraint, "multiple_of"):
|
||||
config["step"] = constraint.multiple_of
|
||||
return config
|
||||
|
||||
|
||||
def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.IMAGE, {
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.STRING, {
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.FLOAT, {
|
||||
**_create_base_config(field_info),
|
||||
**_get_number_constraints_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.INT, {
|
||||
**_create_base_config(field_info),
|
||||
**_get_number_constraints_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_combo_input(
|
||||
field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
|
||||
) -> NodeInput:
|
||||
combo_config = {}
|
||||
if enum_type is not None:
|
||||
combo_config["options"] = [option.value for option in enum_type]
|
||||
combo_config = {
|
||||
**combo_config,
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
return IO.COMBO, combo_config
|
||||
|
||||
|
||||
def model_field_to_node_input(
|
||||
input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
|
||||
) -> NodeInput:
|
||||
"""
|
||||
Maps a field from a Pydantic model to a Comfy node input.
|
||||
|
||||
Args:
|
||||
input_type: The type of the input.
|
||||
base_model: The Pydantic model to map the field from.
|
||||
field_name: The name of the field to map.
|
||||
**kwargs: Additional key/values to include in the input options.
|
||||
|
||||
Note:
|
||||
For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
|
||||
|
||||
Example:
|
||||
>>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
|
||||
>>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
|
||||
>>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
|
||||
"""
|
||||
field_info: FieldInfo = base_model.model_fields[field_name]
|
||||
result: NodeInput
|
||||
|
||||
if input_type == IO.IMAGE:
|
||||
result = _model_field_to_image_input(field_info, **kwargs)
|
||||
elif input_type == IO.STRING:
|
||||
result = _model_field_to_string_input(field_info, **kwargs)
|
||||
elif input_type == IO.FLOAT:
|
||||
result = _model_field_to_float_input(field_info, **kwargs)
|
||||
elif input_type == IO.INT:
|
||||
result = _model_field_to_int_input(field_info, **kwargs)
|
||||
elif input_type == IO.COMBO:
|
||||
result = _model_field_to_combo_input(field_info, **kwargs)
|
||||
else:
|
||||
message = f"Invalid input type: {input_type}"
|
||||
raise ValueError(message)
|
||||
|
||||
return result
|
||||
@@ -3,7 +3,7 @@ from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
from comfy_api_nodes.apis.bfl import (
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
@@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.06}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.05}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.05}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [pro] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
|
||||
PRICE_BADGE_EXPR = """
|
||||
(
|
||||
$MP := 1024 * 1024;
|
||||
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
|
||||
$outputCost := 0.03 + 0.015 * ($outMP - 1);
|
||||
inputs.images.connected
|
||||
? {
|
||||
"type":"range_usd",
|
||||
"min_usd": $outputCost + 0.015,
|
||||
"max_usd": $outputCost + 0.12,
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
: {"type":"usd","usd": $outputCost}
|
||||
)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
@@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
|
||||
expr=cls.PRICE_BADGE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode):
|
||||
NODE_ID = "Flux2MaxImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [max] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
|
||||
PRICE_BADGE_EXPR = """
|
||||
(
|
||||
$MP := 1024 * 1024;
|
||||
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
|
||||
$outputCost := 0.07 + 0.03 * ($outMP - 1);
|
||||
|
||||
inputs.images.connected
|
||||
? {
|
||||
"type":"range_usd",
|
||||
"min_usd": $outputCost + 0.03,
|
||||
"max_usd": $outputCost + 0.24,
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
: {"type":"usd","usd": $outputCost}
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
|
||||
198
comfy_api_nodes/nodes_bria.py
Normal file
198
comfy_api_nodes/nodes_bria.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bria import (
|
||||
BriaEditImageRequest,
|
||||
BriaResponse,
|
||||
BriaStatusResponse,
|
||||
InputModerationSettings,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
|
||||
|
||||
class BriaImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
category="api node/image/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["FIBO"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Instruction to edit image",
|
||||
),
|
||||
IO.String.Input("negative_prompt", multiline=True, default=""),
|
||||
IO.String.Input(
|
||||
"structured_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="A string containing the structured edit prompt in JSON format. "
|
||||
"Use this instead of usual prompt for precise, programmatic control.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=1,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance_scale",
|
||||
default=3,
|
||||
min=3,
|
||||
max=5,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Higher value makes the image follow the prompt more closely.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=50,
|
||||
min=20,
|
||||
max=50,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"moderation",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"prompt_content_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_input_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_output_moderation", default=True
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="Moderation settings",
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="If omitted, the edit applies to the entire image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(display_name="structured_prompt"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.04}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
structured_prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
steps: int,
|
||||
moderation: InputModerationSettings,
|
||||
mask: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt and not structured_prompt:
|
||||
raise ValueError(
|
||||
"One of prompt or structured_prompt is required to be non-empty."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
mask_url = None
|
||||
if mask is not None:
|
||||
mask_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
cls,
|
||||
convert_mask_to_image(mask),
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading mask",
|
||||
)
|
||||
)[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
|
||||
data=BriaEditImageRequest(
|
||||
instruction=prompt if prompt else None,
|
||||
structured_instruction=structured_prompt if structured_prompt else None,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading image",
|
||||
),
|
||||
mask=mask_url,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
model_version=model,
|
||||
steps_num=steps,
|
||||
prompt_content_moderation=moderation.get(
|
||||
"prompt_content_moderation", False
|
||||
),
|
||||
visual_input_content_moderation=moderation.get(
|
||||
"visual_input_moderation", False
|
||||
),
|
||||
visual_output_content_moderation=moderation.get(
|
||||
"visual_output_moderation", False
|
||||
),
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaResponse,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_image_tensor(response.result.image_url),
|
||||
response.result.structured_prompt,
|
||||
)
|
||||
|
||||
|
||||
class BriaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
BriaImageEditNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BriaExtension:
|
||||
return BriaExtension()
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
@@ -126,6 +126,9 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -367,6 +370,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03;
|
||||
{
|
||||
"type":"usd",
|
||||
"usd": $price,
|
||||
"format": { "suffix":" x images/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -461,7 +477,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
options=[
|
||||
"seedance-1-5-pro-251215",
|
||||
"seedance-1-0-pro-250528",
|
||||
"seedance-1-0-lite-t2v-250428",
|
||||
"seedance-1-0-pro-fast-251015",
|
||||
],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
@@ -512,6 +533,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -522,6 +549,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -535,7 +563,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
camera_fixed: bool,
|
||||
watermark: bool,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "seedance-1-5-pro-251215" and duration < 4:
|
||||
raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.")
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
|
||||
|
||||
@@ -550,7 +581,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
)
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
|
||||
payload=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt)],
|
||||
generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None,
|
||||
),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@@ -567,7 +602,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
options=[
|
||||
"seedance-1-5-pro-251215",
|
||||
"seedance-1-0-pro-250528",
|
||||
"seedance-1-0-lite-i2v-250428",
|
||||
"seedance-1-0-pro-fast-251015",
|
||||
],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
@@ -622,6 +662,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -632,6 +678,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -646,7 +693,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
camera_fixed: bool,
|
||||
watermark: bool,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "seedance-1-5-pro-251215" and duration < 4:
|
||||
raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.")
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
@@ -668,6 +718,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))],
|
||||
generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None,
|
||||
),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
@@ -685,7 +736,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
@@ -744,6 +795,12 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="This parameter is ignored for any model except seedance-1-5-pro.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -754,6 +811,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -769,7 +827,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
camera_fixed: bool,
|
||||
watermark: bool,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "seedance-1-5-pro-251215" and duration < 4:
|
||||
raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.")
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
|
||||
for i in (first_frame, last_frame):
|
||||
@@ -802,6 +863,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"),
|
||||
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"),
|
||||
],
|
||||
generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None,
|
||||
),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
@@ -877,6 +939,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$priceByModel := {
|
||||
"seedance-1-0-pro": {
|
||||
"480p":[0.23,0.24],
|
||||
"720p":[0.51,0.56]
|
||||
},
|
||||
"seedance-1-0-lite": {
|
||||
"480p":[0.17,0.18],
|
||||
"720p":[0.37,0.41]
|
||||
}
|
||||
};
|
||||
$model := widgets.model;
|
||||
$modelKey :=
|
||||
$contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" :
|
||||
"seedance-1-0-lite";
|
||||
$resolution := widgets.resolution;
|
||||
$resKey :=
|
||||
$contains($resolution, "720") ? "720p" :
|
||||
"480p";
|
||||
$modelPrices := $lookup($priceByModel, $modelKey);
|
||||
$baseRange := $lookup($modelPrices, $resKey);
|
||||
$min10s := $baseRange[0];
|
||||
$max10s := $baseRange[1];
|
||||
$scale := widgets.duration / 10;
|
||||
$minCost := $min10s * $scale;
|
||||
$maxCost := $max10s * $scale;
|
||||
($minCost = $maxCost)
|
||||
? {"type":"usd","usd": $minCost}
|
||||
: {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -946,6 +1043,59 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
)
|
||||
|
||||
|
||||
PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$priceByModel := {
|
||||
"seedance-1-5-pro": {
|
||||
"480p":[0.12,0.12],
|
||||
"720p":[0.26,0.26],
|
||||
"1080p":[0.58,0.59]
|
||||
},
|
||||
"seedance-1-0-pro": {
|
||||
"480p":[0.23,0.24],
|
||||
"720p":[0.51,0.56],
|
||||
"1080p":[1.18,1.22]
|
||||
},
|
||||
"seedance-1-0-pro-fast": {
|
||||
"480p":[0.09,0.1],
|
||||
"720p":[0.21,0.23],
|
||||
"1080p":[0.47,0.49]
|
||||
},
|
||||
"seedance-1-0-lite": {
|
||||
"480p":[0.17,0.18],
|
||||
"720p":[0.37,0.41],
|
||||
"1080p":[0.85,0.88]
|
||||
}
|
||||
};
|
||||
$model := widgets.model;
|
||||
$modelKey :=
|
||||
$contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" :
|
||||
$contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" :
|
||||
$contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" :
|
||||
"seedance-1-0-lite";
|
||||
$resolution := widgets.resolution;
|
||||
$resKey :=
|
||||
$contains($resolution, "1080") ? "1080p" :
|
||||
$contains($resolution, "720") ? "720p" :
|
||||
"480p";
|
||||
$modelPrices := $lookup($priceByModel, $modelKey);
|
||||
$baseRange := $lookup($modelPrices, $resKey);
|
||||
$min10s := $baseRange[0];
|
||||
$max10s := $baseRange[1];
|
||||
$scale := widgets.duration / 10;
|
||||
$audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1;
|
||||
$minCost := $min10s * $scale * $audioMultiplier;
|
||||
$maxCost := $max10s * $scale * $audioMultiplier;
|
||||
($minCost = $maxCost)
|
||||
? {"type":"usd","usd": $minCost, "format": { "approximate": true }}
|
||||
: {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
@@ -130,7 +130,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
if response.candidates is None:
|
||||
if not response.candidates:
|
||||
if response.promptFeedback and response.promptFeedback.blockReason:
|
||||
feedback = response.promptFeedback
|
||||
raise ValueError(
|
||||
@@ -141,14 +141,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
"try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
|
||||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
blocked_reasons = []
|
||||
for candidate in response.candidates:
|
||||
if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT":
|
||||
blocked_reasons.append(candidate.finishReason)
|
||||
continue
|
||||
if candidate.content is None or candidate.content.parts is None:
|
||||
continue
|
||||
for part in candidate.content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
|
||||
if not parts and blocked_reasons:
|
||||
raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}")
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
@@ -309,6 +319,30 @@ class GeminiNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "gemini-2.5-flash") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0003, 0.0025],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"}
|
||||
}
|
||||
: $contains($m, "gemini-2.5-pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00125, 0.01],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gemini-3-pro-preview") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.002, 0.012],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: {"type":"text", "text":"Token-based"}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -570,6 +604,9 @@ class GeminiImage(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -700,6 +737,19 @@ class GeminiImage2(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := widgets.resolution;
|
||||
($contains($r,"1k") or $contains($r,"2k"))
|
||||
? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}}
|
||||
: $contains($r,"4k")
|
||||
? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}}
|
||||
: {"type":"text","text":"Token-based"}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
from comfy_api_nodes.apis.ideogram import (
|
||||
IdeogramGenerateRequest,
|
||||
IdeogramGenerateResponse,
|
||||
ImageRequest,
|
||||
@@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode):
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode):
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode):
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode):
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode):
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode):
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$speed := widgets.rendering_speed;
|
||||
$hasChar := inputs.character_image.connected;
|
||||
$base :=
|
||||
$contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) :
|
||||
$contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) :
|
||||
$contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) :
|
||||
0.0858;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -49,7 +49,7 @@ from comfy_api_nodes.apis import (
|
||||
KlingCharacterEffectModelName,
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
from comfy_api_nodes.apis.kling import (
|
||||
ImageToVideoWithAudioRequest,
|
||||
MotionControlRequest,
|
||||
OmniImageParamImage,
|
||||
@@ -764,6 +764,33 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.mode;
|
||||
$contains($m,"v2-5-turbo")
|
||||
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
|
||||
: $contains($m,"v2-1-master")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
|
||||
: $contains($m,"v2-master")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
|
||||
: $contains($m,"v1-6")
|
||||
? (
|
||||
$contains($m,"pro")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
|
||||
)
|
||||
: $contains($m,"v1")
|
||||
? (
|
||||
$contains($m,"pro")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
|
||||
)
|
||||
: {"type":"usd","usd":0.14}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -818,6 +845,16 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$rates := {"std": 0.084, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -886,6 +923,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$rates := {"std": 0.084, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -981,6 +1028,16 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$rates := {"std": 0.084, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1056,6 +1113,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$rates := {"std": 0.126, "pro": 0.168};
|
||||
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1142,6 +1209,16 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := (widgets.resolution = "720p") ? "std" : "pro";
|
||||
$rates := {"std": 0.126, "pro": 0.168};
|
||||
{"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1228,6 +1305,9 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.028}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1313,6 +1393,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1375,6 +1458,33 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := widgets.mode;
|
||||
$model := widgets.model_name;
|
||||
$dur := widgets.duration;
|
||||
$contains($model,"v2-5-turbo")
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
|
||||
: ($contains($model,"v2-1-master") or $contains($model,"v2-master"))
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
|
||||
: ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5"))
|
||||
? (
|
||||
$contains($mode,"pro")
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
|
||||
)
|
||||
: $contains($model,"v1")
|
||||
? (
|
||||
$contains($mode,"pro")
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
|
||||
)
|
||||
: {"type":"usd","usd":0.14}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1448,6 +1558,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.49}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1518,6 +1631,33 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.mode;
|
||||
$contains($m,"v2-5-turbo")
|
||||
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
|
||||
: $contains($m,"v2-1")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: $contains($m,"v2-master")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
|
||||
: $contains($m,"v1-6")
|
||||
? (
|
||||
$contains($m,"pro")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
|
||||
)
|
||||
: $contains($m,"v1")
|
||||
? (
|
||||
$contains($m,"pro")
|
||||
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
|
||||
)
|
||||
: {"type":"usd","usd":0.14}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1583,6 +1723,9 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.28}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1664,6 +1807,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$mode := widgets.mode;
|
||||
$model := widgets.model_name;
|
||||
$dur := widgets.duration;
|
||||
($contains($model,"v1-6") or $contains($model,"v1-5"))
|
||||
? (
|
||||
$contains($mode,"pro")
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
|
||||
)
|
||||
: $contains($model,"v1")
|
||||
? (
|
||||
$contains($mode,"pro")
|
||||
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
|
||||
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
|
||||
)
|
||||
: {"type":"usd","usd":0.14}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1728,6 +1894,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]),
|
||||
expr="""
|
||||
(
|
||||
($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom"))
|
||||
? {"type":"usd","usd":0.49}
|
||||
: {"type":"usd","usd":0.28}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1782,6 +1958,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1842,6 +2021,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1892,6 +2074,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.7}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1991,6 +2176,19 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model_name;
|
||||
$base :=
|
||||
$contains($m,"kling-v1-5")
|
||||
? (inputs.image.connected ? 0.028 : 0.014)
|
||||
: ($contains($m,"kling-v1") ? 0.0035 : 0.014);
|
||||
{"type":"usd","usd": $base * widgets.n}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2074,6 +2272,10 @@ class TextToVideoWithAudio(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
|
||||
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2138,6 +2340,10 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
|
||||
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2218,6 +2424,15 @@ class MotionControl(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"std": 0.07, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel):
|
||||
image_uri: str | None = Field(None)
|
||||
|
||||
|
||||
PRICE_BADGE = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24},
|
||||
"ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16}
|
||||
};
|
||||
$modelPrices := $lookup($prices, $lowercase(widgets.model));
|
||||
$pps := $lookup($modelPrices, widgets.resolution);
|
||||
{"type":"usd","usd": $pps * widgets.duration}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
from comfy_api_nodes.apis.luma import (
|
||||
LumaAspectRatio,
|
||||
LumaCharacterRef,
|
||||
LumaConceptChain,
|
||||
@@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m,"photon-flash-1")
|
||||
? {"type":"usd","usd":0.0027}
|
||||
: $contains($m,"photon-1")
|
||||
? {"type":"usd","usd":0.0104}
|
||||
: {"type":"usd","usd":0.0246}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m,"photon-flash-1")
|
||||
? {"type":"usd","usd":0.0027}
|
||||
: $contains($m,"photon-1")
|
||||
? {"type":"usd","usd":0.0104}
|
||||
: {"type":"usd","usd":0.0246}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {
|
||||
"ray-flash-2": {
|
||||
"5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2},
|
||||
"9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36}
|
||||
},
|
||||
"ray-2": {
|
||||
"5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57},
|
||||
"9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03}
|
||||
}
|
||||
};
|
||||
|
||||
$m := widgets.model;
|
||||
$d := widgets.duration;
|
||||
$r := widgets.resolution;
|
||||
|
||||
$modelKey :=
|
||||
$contains($m,"ray-flash-2") ? "ray-flash-2" :
|
||||
$contains($m,"ray-2") ? "ray-2" :
|
||||
$contains($m,"ray-1-6") ? "ray-1-6" :
|
||||
"other";
|
||||
|
||||
$durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : "";
|
||||
$resKey :=
|
||||
$contains($r,"4k") ? "4k" :
|
||||
$contains($r,"1080p") ? "1080p" :
|
||||
$contains($r,"720p") ? "720p" :
|
||||
$contains($r,"540p") ? "540p" : "";
|
||||
|
||||
$modelPrices := $lookup($p, $modelKey);
|
||||
$durPrices := $lookup($modelPrices, $durKey);
|
||||
$v := $lookup($durPrices, $resKey);
|
||||
|
||||
$price :=
|
||||
($modelKey = "ray-1-6") ? 0.5 :
|
||||
($modelKey = "other") ? 0.79 :
|
||||
($exists($v) ? $v : 0.79);
|
||||
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
||||
790
comfy_api_nodes/nodes_meshy.py
Normal file
790
comfy_api_nodes/nodes_meshy.py
Normal file
@@ -0,0 +1,790 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.meshy import (
|
||||
InputShouldRemesh,
|
||||
InputShouldTexture,
|
||||
MeshyAnimationRequest,
|
||||
MeshyAnimationResult,
|
||||
MeshyImageToModelRequest,
|
||||
MeshyModelResult,
|
||||
MeshyMultiImageToModelRequest,
|
||||
MeshyRefineTask,
|
||||
MeshyRiggedResult,
|
||||
MeshyRiggingRequest,
|
||||
MeshyTaskResponse,
|
||||
MeshyTextToModelRequest,
|
||||
MeshyTextureRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
class MeshyTextToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextToModelNode",
|
||||
display_name="Meshy: Text to Model",
|
||||
category="api node/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Combo.Input("style", options=["realistic", "sculpture"]),
|
||||
IO.DynamicCombo.Input(
|
||||
"should_remesh",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Combo.Input("topology", options=["triangle", "quad"]),
|
||||
IO.Int.Input(
|
||||
"target_polycount",
|
||||
default=300000,
|
||||
min=100,
|
||||
max=300000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="When set to false, returns an unprocessed triangular mesh.",
|
||||
),
|
||||
IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]),
|
||||
IO.Combo.Input(
|
||||
"pose_mode",
|
||||
options=["", "A-pose", "T-pose"],
|
||||
tooltip="Specify the pose mode for the generated model.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.8}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
style: str,
|
||||
should_remesh: InputShouldRemesh,
|
||||
symmetry_mode: str,
|
||||
pose_mode: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=600)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyTextToModelRequest(
|
||||
prompt=prompt,
|
||||
art_style=style,
|
||||
ai_model=model,
|
||||
topology=should_remesh.get("topology", None),
|
||||
target_polycount=should_remesh.get("target_polycount", None),
|
||||
should_remesh=should_remesh["should_remesh"] == "true",
|
||||
symmetry_mode=symmetry_mode,
|
||||
pose_mode=pose_mode.lower(),
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyRefineNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRefineNode",
|
||||
display_name="Meshy: Refine Draft Model",
|
||||
category="api node/3d/Meshy",
|
||||
description="Refine a previously created draft model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
IO.Boolean.Input(
|
||||
"enable_pbr",
|
||||
default=False,
|
||||
tooltip="Generate PBR Maps (metallic, roughness, normal) in addition to the base color. "
|
||||
"Note: this should be set to false when using Sculpture style, "
|
||||
"as Sculpture style generates its own set of PBR maps.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Provide a text prompt to guide the texturing process. "
|
||||
"Maximum 600 characters. Cannot be used at the same time as 'texture_image'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"texture_image",
|
||||
tooltip="Only one of 'texture_image' or 'texture_prompt' may be used at the same time.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
meshy_task_id: str,
|
||||
enable_pbr: bool,
|
||||
texture_prompt: str,
|
||||
texture_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if texture_prompt and texture_image is not None:
|
||||
raise ValueError("texture_prompt and texture_image cannot be used at the same time")
|
||||
texture_image_url = None
|
||||
if texture_prompt:
|
||||
validate_string(texture_prompt, field_name="texture_prompt", max_length=600)
|
||||
if texture_image is not None:
|
||||
texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyRefineTask(
|
||||
preview_task_id=meshy_task_id,
|
||||
enable_pbr=enable_pbr,
|
||||
texture_prompt=texture_prompt if texture_prompt else None,
|
||||
texture_image_url=texture_image_url,
|
||||
ai_model=model,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyImageToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyImageToModelNode",
|
||||
display_name="Meshy: Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.DynamicCombo.Input(
|
||||
"should_remesh",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Combo.Input("topology", options=["triangle", "quad"]),
|
||||
IO.Int.Input(
|
||||
"target_polycount",
|
||||
default=300000,
|
||||
min=100,
|
||||
max=300000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="When set to false, returns an unprocessed triangular mesh.",
|
||||
),
|
||||
IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]),
|
||||
IO.DynamicCombo.Input(
|
||||
"should_texture",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"enable_pbr",
|
||||
default=False,
|
||||
tooltip="Generate PBR Maps (metallic, roughness, normal) "
|
||||
"in addition to the base color.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Provide a text prompt to guide the texturing process. "
|
||||
"Maximum 600 characters. Cannot be used at the same time as 'texture_image'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"texture_image",
|
||||
tooltip="Only one of 'texture_image' or 'texture_prompt' "
|
||||
"may be used at the same time.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="Determines whether textures are generated. "
|
||||
"Setting it to false skips the texture phase and returns a mesh without textures.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"pose_mode",
|
||||
options=["", "A-pose", "T-pose"],
|
||||
tooltip="Specify the pose mode for the generated model.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"true": 1.2, "false": 0.8};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.should_texture)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
should_remesh: InputShouldRemesh,
|
||||
symmetry_mode: str,
|
||||
should_texture: InputShouldTexture,
|
||||
pose_mode: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
texture = should_texture["should_texture"] == "true"
|
||||
texture_image_url = texture_prompt = None
|
||||
if texture:
|
||||
if should_texture["texture_prompt"] and should_texture["texture_image"] is not None:
|
||||
raise ValueError("texture_prompt and texture_image cannot be used at the same time")
|
||||
if should_texture["texture_prompt"]:
|
||||
validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600)
|
||||
texture_prompt = should_texture["texture_prompt"]
|
||||
if should_texture["texture_image"] is not None:
|
||||
texture_image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
cls, should_texture["texture_image"], wait_label="Uploading texture"
|
||||
)
|
||||
)[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyImageToModelRequest(
|
||||
image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0],
|
||||
ai_model=model,
|
||||
topology=should_remesh.get("topology", None),
|
||||
target_polycount=should_remesh.get("target_polycount", None),
|
||||
symmetry_mode=symmetry_mode,
|
||||
should_remesh=should_remesh["should_remesh"] == "true",
|
||||
should_texture=texture,
|
||||
enable_pbr=should_texture.get("enable_pbr", None),
|
||||
pose_mode=pose_mode.lower(),
|
||||
texture_prompt=texture_prompt,
|
||||
texture_image_url=texture_image_url,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyMultiImageToModelNode",
|
||||
display_name="Meshy: Multi-Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=2, max=4),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"should_remesh",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Combo.Input("topology", options=["triangle", "quad"]),
|
||||
IO.Int.Input(
|
||||
"target_polycount",
|
||||
default=300000,
|
||||
min=100,
|
||||
max=300000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="When set to false, returns an unprocessed triangular mesh.",
|
||||
),
|
||||
IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]),
|
||||
IO.DynamicCombo.Input(
|
||||
"should_texture",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"enable_pbr",
|
||||
default=False,
|
||||
tooltip="Generate PBR Maps (metallic, roughness, normal) "
|
||||
"in addition to the base color.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Provide a text prompt to guide the texturing process. "
|
||||
"Maximum 600 characters. Cannot be used at the same time as 'texture_image'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"texture_image",
|
||||
tooltip="Only one of 'texture_image' or 'texture_prompt' "
|
||||
"may be used at the same time.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="Determines whether textures are generated. "
|
||||
"Setting it to false skips the texture phase and returns a mesh without textures.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"pose_mode",
|
||||
options=["", "A-pose", "T-pose"],
|
||||
tooltip="Specify the pose mode for the generated model.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"true": 0.6, "false": 0.2};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.should_texture)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
images: IO.Autogrow.Type,
|
||||
should_remesh: InputShouldRemesh,
|
||||
symmetry_mode: str,
|
||||
should_texture: InputShouldTexture,
|
||||
pose_mode: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
texture = should_texture["should_texture"] == "true"
|
||||
texture_image_url = texture_prompt = None
|
||||
if texture:
|
||||
if should_texture["texture_prompt"] and should_texture["texture_image"] is not None:
|
||||
raise ValueError("texture_prompt and texture_image cannot be used at the same time")
|
||||
if should_texture["texture_prompt"]:
|
||||
validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600)
|
||||
texture_prompt = should_texture["texture_prompt"]
|
||||
if should_texture["texture_image"] is not None:
|
||||
texture_image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
cls, should_texture["texture_image"], wait_label="Uploading texture"
|
||||
)
|
||||
)[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyMultiImageToModelRequest(
|
||||
image_urls=await upload_images_to_comfyapi(
|
||||
cls, list(images.values()), wait_label="Uploading base images"
|
||||
),
|
||||
ai_model=model,
|
||||
topology=should_remesh.get("topology", None),
|
||||
target_polycount=should_remesh.get("target_polycount", None),
|
||||
symmetry_mode=symmetry_mode,
|
||||
should_remesh=should_remesh["should_remesh"] == "true",
|
||||
should_texture=texture,
|
||||
enable_pbr=should_texture.get("enable_pbr", None),
|
||||
pose_mode=pose_mode.lower(),
|
||||
texture_prompt=texture_prompt,
|
||||
texture_image_url=texture_image_url,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyRigModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRigModelNode",
|
||||
display_name="Meshy: Rig Model",
|
||||
category="api node/3d/Meshy",
|
||||
description="Provides a rigged character in standard formats. "
|
||||
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
|
||||
"or humanoid assets with unclear limb and body structure.",
|
||||
inputs=[
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
IO.Float.Input(
|
||||
"height_meters",
|
||||
min=0.1,
|
||||
max=15.0,
|
||||
default=1.7,
|
||||
tooltip="The approximate height of the character model in meters. "
|
||||
"This aids in scaling and rigging accuracy.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"texture_image",
|
||||
tooltip="The model's UV-unwrapped base color texture image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
meshy_task_id: str,
|
||||
height_meters: float,
|
||||
texture_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
texture_image_url = None
|
||||
if texture_image is not None:
|
||||
texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyRiggingRequest(
|
||||
input_task_id=meshy_task_id,
|
||||
height_meters=height_meters,
|
||||
texture_image_url=texture_image_url,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
|
||||
response_model=MeshyRiggedResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(
|
||||
result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
|
||||
)
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyAnimateModelNode",
|
||||
display_name="Meshy: Animate Model",
|
||||
category="api node/3d/Meshy",
|
||||
description="Apply a specific animation action to a previously rigged character.",
|
||||
inputs=[
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
|
||||
IO.Int.Input(
|
||||
"action_id",
|
||||
default=0,
|
||||
min=0,
|
||||
max=696,
|
||||
tooltip="Visit https://docs.meshy.ai/en/api/animation-library for a list of available values.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.12}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
rig_task_id: str,
|
||||
action_id: int,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyAnimationRequest(
|
||||
rig_task_id=rig_task_id,
|
||||
action_id=action_id,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
|
||||
response_model=MeshyAnimationResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyTextureNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextureNode",
|
||||
display_name="Meshy: Texture Model",
|
||||
category="api node/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
IO.Boolean.Input(
|
||||
"enable_original_uv",
|
||||
default=True,
|
||||
tooltip="Use the original UV of the model instead of generating new UVs. "
|
||||
"When enabled, Meshy preserves existing textures from the uploaded model. "
|
||||
"If the model has no original UV, the quality of the output might not be as good.",
|
||||
),
|
||||
IO.Boolean.Input("pbr", default=False),
|
||||
IO.String.Input(
|
||||
"text_style_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Describe your desired texture style of the object using text. Maximum 600 characters."
|
||||
"Maximum 600 characters. Cannot be used at the same time as 'image_style'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_style",
|
||||
optional=True,
|
||||
tooltip="A 2d image to guide the texturing process. "
|
||||
"Can not be used at the same time with 'text_style_prompt'.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
meshy_task_id: str,
|
||||
enable_original_uv: bool,
|
||||
pbr: bool,
|
||||
text_style_prompt: str,
|
||||
image_style: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if text_style_prompt and image_style is not None:
|
||||
raise ValueError("text_style_prompt and image_style cannot be used at the same time")
|
||||
if not text_style_prompt and image_style is None:
|
||||
raise ValueError("Either text_style_prompt or image_style is required")
|
||||
image_style_url = None
|
||||
if image_style is not None:
|
||||
image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"),
|
||||
response_model=MeshyTaskResponse,
|
||||
data=MeshyTextureRequest(
|
||||
input_task_id=meshy_task_id,
|
||||
ai_model=model,
|
||||
enable_original_uv=enable_original_uv,
|
||||
enable_pbr=pbr,
|
||||
text_style_prompt=text_style_prompt if text_style_prompt else None,
|
||||
image_style_url=image_style_url,
|
||||
),
|
||||
)
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MeshyTextToModelNode,
|
||||
MeshyRefineNode,
|
||||
MeshyImageToModelNode,
|
||||
MeshyMultiImageToModelNode,
|
||||
MeshyRigModelNode,
|
||||
MeshyAnimateModelNode,
|
||||
MeshyTextureNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MeshyExtension:
|
||||
return MeshyExtension()
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.minimax_api import (
|
||||
from comfy_api_nodes.apis.minimax import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MiniMaxModel,
|
||||
MinimaxTaskResultResponse,
|
||||
@@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.43}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.43}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"768p": {"6": 0.28, "10": 0.56},
|
||||
"1080p": {"6": 0.49}
|
||||
};
|
||||
$resPrices := $lookup($prices, $lowercase(widgets.resolution));
|
||||
$price := $lookup($resPrices, $string(widgets.duration));
|
||||
{"type":"usd","usd": $price ? $price : 0.43}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
from comfy_api_nodes.apis.moonvalley import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
@@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 1.5}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 2.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 1.5}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -10,24 +10,18 @@ from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
CreateModelResponseProperties,
|
||||
Detail,
|
||||
InputContent,
|
||||
from comfy_api_nodes.apis.openai import (
|
||||
InputFileContent,
|
||||
InputImageContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputTextContent,
|
||||
Item,
|
||||
ModelResponseProperties,
|
||||
OpenAICreateResponse,
|
||||
OpenAIResponse,
|
||||
OutputContent,
|
||||
)
|
||||
from comfy_api_nodes.apis.openai_api import (
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
OpenAIResponse,
|
||||
OutputContent,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@@ -160,6 +154,23 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$size := widgets.size;
|
||||
$nRaw := widgets.n;
|
||||
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
|
||||
|
||||
$base :=
|
||||
$contains($size, "256x256") ? 0.016 :
|
||||
$contains($size, "512x512") ? 0.018 :
|
||||
0.02;
|
||||
|
||||
{"type":"usd","usd": $round($base * $n, 3)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -249,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2 ** 31 - 1,
|
||||
max=2**31 - 1,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
@@ -287,6 +298,25 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]),
|
||||
expr="""
|
||||
(
|
||||
$size := widgets.size;
|
||||
$q := widgets.quality;
|
||||
$hd := $contains($q, "hd");
|
||||
|
||||
$price :=
|
||||
$contains($size, "1024x1024")
|
||||
? ($hd ? 0.08 : 0.04)
|
||||
: (($contains($size, "1792x1024") or $contains($size, "1024x1792"))
|
||||
? ($hd ? 0.12 : 0.08)
|
||||
: 0.04);
|
||||
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -334,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1",
|
||||
display_name="OpenAI GPT Image 1.5",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -348,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2 ** 31 - 1,
|
||||
max=2**31 - 1,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
@@ -399,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
default="gpt-image-1.5",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -411,6 +442,28 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.046, 0.07],
|
||||
"high": [0.167, 0.3]
|
||||
};
|
||||
$range := $lookup($ranges, widgets.quality);
|
||||
$n := widgets.n;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $range[0],
|
||||
"max_usd": $range[1],
|
||||
"format": { "suffix": " x " & $string($n) & "/Run" }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -442,8 +495,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
files = []
|
||||
batch_size = image.shape[0]
|
||||
for i in range(batch_size):
|
||||
single_image = image[i: i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
@@ -465,7 +518,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
@@ -566,32 +619,95 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "o4-mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0011, 0.0044],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "o1-pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.15, 0.6],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "o1") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.015, 0.06],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "o3-mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0011, 0.0044],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "o3") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.01, 0.04],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4o") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0025, 0.01],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4.1-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0001, 0.0004],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4.1-mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0004, 0.0016],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4.1") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.002, 0.008],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00005, 0.0004],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5-mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.002],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00125, 0.01],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: {"type": "text", "text": "Token-based"}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_message_content_from_response(
|
||||
cls, response: OpenAIResponse
|
||||
) -> list[OutputContent]:
|
||||
def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]:
|
||||
"""Extract message content from the API response."""
|
||||
for output in response.output:
|
||||
if output.root.type == "message":
|
||||
return output.root.content
|
||||
if output.type == "message":
|
||||
return output.content
|
||||
raise TypeError("No output message found in response")
|
||||
|
||||
@classmethod
|
||||
def get_text_from_message_content(
|
||||
cls, message_content: list[OutputContent]
|
||||
) -> str:
|
||||
def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str:
|
||||
"""Extract text content from message content."""
|
||||
for content_item in message_content:
|
||||
if content_item.root.type == "output_text":
|
||||
return str(content_item.root.text)
|
||||
if content_item.type == "output_text":
|
||||
return str(content_item.text)
|
||||
return "No text output found in response"
|
||||
|
||||
@classmethod
|
||||
def tensor_to_input_image_content(
|
||||
cls, image: torch.Tensor, detail_level: Detail = "auto"
|
||||
) -> InputImageContent:
|
||||
def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent:
|
||||
"""Convert a tensor to an input image content object."""
|
||||
return InputImageContent(
|
||||
detail=detail_level,
|
||||
@@ -605,9 +721,9 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
image: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
) -> InputMessageContentList:
|
||||
) -> list[InputTextContent | InputImageContent | InputFileContent]:
|
||||
"""Create a list of input message contents from prompt and optional image."""
|
||||
content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
|
||||
content_list: list[InputTextContent | InputImageContent | InputFileContent] = [
|
||||
InputTextContent(text=prompt, type="input_text"),
|
||||
]
|
||||
if image is not None:
|
||||
@@ -619,13 +735,9 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
type="input_image",
|
||||
)
|
||||
)
|
||||
|
||||
if files is not None:
|
||||
content_list.extend(files)
|
||||
|
||||
return InputMessageContentList(
|
||||
root=content_list,
|
||||
)
|
||||
return content_list
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
@@ -635,7 +747,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
advanced_options: CreateModelResponseProperties | None = None,
|
||||
advanced_options: ModelResponseProperties | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@@ -646,36 +758,28 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
response_model=OpenAIResponse,
|
||||
data=OpenAICreateResponse(
|
||||
input=[
|
||||
Item(
|
||||
root=InputMessage(
|
||||
content=cls.create_input_message_contents(
|
||||
prompt, images, files
|
||||
),
|
||||
role="user",
|
||||
)
|
||||
InputMessage(
|
||||
content=cls.create_input_message_contents(prompt, images, files),
|
||||
role="user",
|
||||
),
|
||||
],
|
||||
store=True,
|
||||
stream=False,
|
||||
model=model,
|
||||
previous_response_id=None,
|
||||
**(
|
||||
advanced_options.model_dump(exclude_none=True)
|
||||
if advanced_options
|
||||
else {}
|
||||
),
|
||||
**(advanced_options.model_dump(exclude_none=True) if advanced_options else {}),
|
||||
),
|
||||
)
|
||||
response_id = create_response.id
|
||||
|
||||
# Get result output
|
||||
result_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"),
|
||||
response_model=OpenAIResponse,
|
||||
status_extractor=lambda response: response.status,
|
||||
completed_statuses=["incomplete", "completed"]
|
||||
)
|
||||
cls,
|
||||
ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"),
|
||||
response_model=OpenAIResponse,
|
||||
status_extractor=lambda response: response.status,
|
||||
completed_statuses=["incomplete", "completed"],
|
||||
)
|
||||
return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
|
||||
|
||||
|
||||
@@ -796,7 +900,7 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
remove depending on model choice.
|
||||
"""
|
||||
return IO.NodeOutput(
|
||||
CreateModelResponseProperties(
|
||||
ModelResponseProperties(
|
||||
instructions=instructions,
|
||||
truncation=truncation,
|
||||
max_output_tokens=max_output_tokens,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.pixverse_api import (
|
||||
from comfy_api_nodes.apis.pixverse import (
|
||||
PixverseTextVideoRequest,
|
||||
PixverseImageVideoRequest,
|
||||
PixverseTransitionVideoRequest,
|
||||
@@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"5": {
|
||||
"1080p": {"normal": 1.2, "fast": 1.2},
|
||||
"720p": {"normal": 0.6, "fast": 1.2},
|
||||
"540p": {"normal": 0.45, "fast": 0.9},
|
||||
"360p": {"normal": 0.45, "fast": 0.9}
|
||||
},
|
||||
"8": {
|
||||
"1080p": {"normal": 1.2, "fast": 1.2},
|
||||
"720p": {"normal": 1.2, "fast": 1.2},
|
||||
"540p": {"normal": 0.9, "fast": 1.2},
|
||||
"360p": {"normal": 0.9, "fast": 1.2}
|
||||
}
|
||||
};
|
||||
$durPrices := $lookup($prices, $string(widgets.duration_seconds));
|
||||
$qualityPrices := $lookup($durPrices, widgets.quality);
|
||||
$price := $lookup($qualityPrices, widgets.motion_mode);
|
||||
{"type":"usd","usd": $price ? $price : 0.9}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class PixVerseExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing_extensions import override
|
||||
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.recraft_api import (
|
||||
from comfy_api_nodes.apis.recraft import (
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
@@ -378,6 +378,10 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
|
||||
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -490,6 +494,10 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
|
||||
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -591,6 +599,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
|
||||
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -692,6 +704,10 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
|
||||
expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -759,6 +775,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 0.01}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -817,6 +837,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.04}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -883,6 +906,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.01}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -929,6 +955,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.004}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -972,6 +1001,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Optional
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from comfy_api_nodes.apis.rodin_api import (
|
||||
from comfy_api_nodes.apis.rodin import (
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
Rodin3DCheckStatusRequest,
|
||||
@@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -16,7 +16,7 @@ from enum import Enum
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
from comfy_api_nodes.apis.runway import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||
@@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
|
||||
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
|
||||
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
|
||||
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.11}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$size := widgets.size;
|
||||
$dur := widgets.duration;
|
||||
$isPro := $contains($m, "sora-2-pro");
|
||||
$isSora2 := $contains($m, "sora-2");
|
||||
$isProSize := ($size = "1024x1792" or $size = "1792x1024");
|
||||
$perSec :=
|
||||
$isPro ? ($isProSize ? 0.5 : 0.3) :
|
||||
$isSora2 ? 0.1 :
|
||||
($isProSize ? 0.5 : 0.1);
|
||||
{"type":"usd","usd": $round($perSec * $dur, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
from comfy_api_nodes.apis.stability import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
@@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.08}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$contains(widgets.model,"large")
|
||||
? {"type":"usd","usd":0.065}
|
||||
: {"type":"usd","usd":0.035}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.01}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -5,7 +5,24 @@ import aiohttp
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import topaz_api
|
||||
from comfy_api_nodes.apis.topaz import (
|
||||
CreateVideoRequest,
|
||||
CreateVideoRequestSource,
|
||||
CreateVideoResponse,
|
||||
ImageAsyncTaskResponse,
|
||||
ImageDownloadResponse,
|
||||
ImageEnhanceRequest,
|
||||
ImageStatusResponse,
|
||||
OutputInformationVideo,
|
||||
Resolution,
|
||||
VideoAcceptResponse,
|
||||
VideoCompleteUploadRequest,
|
||||
VideoCompleteUploadRequestPart,
|
||||
VideoCompleteUploadResponse,
|
||||
VideoEnhancementFilter,
|
||||
VideoFrameInterpolationFilter,
|
||||
VideoStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
@@ -153,13 +170,13 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(
|
||||
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
|
||||
cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
response_model=topaz_api.ImageAsyncTaskResponse,
|
||||
data=topaz_api.ImageEnhanceRequest(
|
||||
response_model=ImageAsyncTaskResponse,
|
||||
data=ImageEnhanceRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
subject_detection=subject_detection,
|
||||
@@ -181,7 +198,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageStatusResponse,
|
||||
response_model=ImageStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: x.credits * 0.08,
|
||||
@@ -193,7 +210,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
results = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageDownloadResponse,
|
||||
response_model=ImageDownloadResponse,
|
||||
monitor_progress=False,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
|
||||
@@ -331,7 +348,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
if target_height % 2 != 0:
|
||||
target_height += 1
|
||||
filters.append(
|
||||
topaz_api.VideoEnhancementFilter(
|
||||
VideoEnhancementFilter(
|
||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
@@ -340,7 +357,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
if interpolation_enabled:
|
||||
target_frame_rate = interpolation_frame_rate
|
||||
filters.append(
|
||||
topaz_api.VideoFrameInterpolationFilter(
|
||||
VideoFrameInterpolationFilter(
|
||||
model=interpolation_model,
|
||||
slowmo=interpolation_slowmo,
|
||||
fps=interpolation_frame_rate,
|
||||
@@ -351,19 +368,19 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=topaz_api.CreateVideoResponse,
|
||||
data=topaz_api.CreateVideoRequest(
|
||||
source=topaz_api.CreateCreateVideoRequestSource(
|
||||
response_model=CreateVideoResponse,
|
||||
data=CreateVideoRequest(
|
||||
source=CreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||
resolution=Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=topaz_api.OutputInformationVideo(
|
||||
resolution=topaz_api.Resolution(width=target_width, height=target_height),
|
||||
output=OutputInformationVideo(
|
||||
resolution=Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
@@ -379,7 +396,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoAcceptResponse,
|
||||
response_model=VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
@@ -402,10 +419,10 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoCompleteUploadResponse,
|
||||
data=topaz_api.VideoCompleteUploadRequest(
|
||||
response_model=VideoCompleteUploadResponse,
|
||||
data=VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
topaz_api.VideoCompleteUploadRequestPart(
|
||||
VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
@@ -417,7 +434,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=topaz_api.VideoStatusResponse,
|
||||
response_model=VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.tripo_api import (
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
TripoConvertModelRequest,
|
||||
@@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model_version",
|
||||
"style",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
"texture_quality",
|
||||
"geometry_quality",
|
||||
],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$style := widgets.style;
|
||||
$hasStyle := ($style != "" and $style != "none");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 20 : ($withTexture ? 20 : 10);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ ($hasStyle ? 5 : 0)
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model_version",
|
||||
"style",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
"texture_quality",
|
||||
"geometry_quality",
|
||||
],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$style := widgets.style;
|
||||
$hasStyle := ($style != "" and $style != "none");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 30 : ($withTexture ? 30 : 20);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ ($hasStyle ? 5 : 0)
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model_version",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
"texture_quality",
|
||||
"geometry_quality",
|
||||
],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 30 : ($withTexture ? 30 : 20);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]),
|
||||
expr="""
|
||||
(
|
||||
$tq := widgets.texture_quality;
|
||||
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.3}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.1}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"quad",
|
||||
"face_limit",
|
||||
"texture_size",
|
||||
"texture_format",
|
||||
"force_symmetry",
|
||||
"flatten_bottom",
|
||||
"flatten_bottom_threshold",
|
||||
"pivot_to_center_bottom",
|
||||
"scale_factor",
|
||||
"with_animation",
|
||||
"pack_uv",
|
||||
"bake",
|
||||
"part_names",
|
||||
"fbx_preset",
|
||||
"export_vertex_colors",
|
||||
"export_orientation",
|
||||
"animate_in_place",
|
||||
],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$face := (widgets.face_limit != null) ? widgets.face_limit : -1;
|
||||
$texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096;
|
||||
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
|
||||
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
|
||||
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
|
||||
$part := widgets.part_names;
|
||||
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
|
||||
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
|
||||
$advanced :=
|
||||
widgets.quad or
|
||||
widgets.force_symmetry or
|
||||
widgets.flatten_bottom or
|
||||
widgets.pivot_to_center_bottom or
|
||||
widgets.with_animation or
|
||||
widgets.pack_uv or
|
||||
widgets.bake or
|
||||
widgets.export_vertex_colors or
|
||||
widgets.animate_in_place or
|
||||
($face != -1) or
|
||||
($texSize != 4096) or
|
||||
($flatThresh != 0) or
|
||||
($scale != 1) or
|
||||
($texFmt != "jpeg") or
|
||||
($part != "") or
|
||||
($fbx != "blender") or
|
||||
($orient != "default");
|
||||
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ from io import BytesIO
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
from comfy_api_nodes.apis.veo import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
VeoGenVidRequest,
|
||||
@@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]),
|
||||
expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$a := widgets.generate_audio;
|
||||
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
|
||||
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
|
||||
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
|
||||
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
|
||||
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
|
||||
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
|
||||
};
|
||||
$m := widgets.model;
|
||||
$ga := (widgets.generate_audio = "true");
|
||||
$seconds := widgets.duration;
|
||||
$modelKey :=
|
||||
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
|
||||
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
|
||||
"";
|
||||
$audioKey := $ga ? "audio" : "no_audio";
|
||||
$modelPrices := $lookup($prices, $modelKey);
|
||||
$pps := $lookup($modelPrices, $audioKey);
|
||||
($pps != null)
|
||||
? {"type":"usd","usd": $pps * $seconds}
|
||||
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -121,6 +121,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -214,6 +217,9 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -317,6 +323,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -426,6 +435,9 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -507,6 +519,17 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$is1080 := widgets.resolution = "1080p";
|
||||
$base := $is1080 ? 0.1 : 0.075;
|
||||
$perSec := $is1080 ? 0.05 : 0.025;
|
||||
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -594,6 +617,39 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$d := widgets.duration;
|
||||
$is1080 := widgets.resolution = "1080p";
|
||||
$contains($m, "pro-fast")
|
||||
? (
|
||||
$base := $is1080 ? 0.08 : 0.04;
|
||||
$perSec := $is1080 ? 0.02 : 0.01;
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
: $contains($m, "pro")
|
||||
? (
|
||||
$base := $is1080 ? 0.275 : 0.075;
|
||||
$perSec := $is1080 ? 0.075 : 0.05;
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
: $contains($m, "turbo")
|
||||
? (
|
||||
$is1080
|
||||
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
|
||||
: (
|
||||
$d <= 1 ? {"type":"usd","usd": 0.04}
|
||||
: $d <= 2 ? {"type":"usd","usd": 0.05}
|
||||
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
|
||||
)
|
||||
)
|
||||
: {"type":"usd","usd": 0.04}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -647,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
"subjects",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_images"),
|
||||
names=["subject1", "subject2", "subject3"],
|
||||
names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"],
|
||||
min=1,
|
||||
),
|
||||
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
|
||||
@@ -682,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
|
||||
IO.Combo.Input("resolution", options=["720p"]),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
@@ -698,6 +754,18 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$is1080 := widgets.resolution = "1080p";
|
||||
$base := $is1080 ? 0.375 : 0.125;
|
||||
$perSec := $is1080 ? 0.05 : 0.025;
|
||||
$audioCost := widgets.audio = true ? 0.075 : 0;
|
||||
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -804,6 +872,38 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$d := widgets.duration;
|
||||
$is1080 := widgets.resolution = "1080p";
|
||||
$contains($m, "pro-fast")
|
||||
? (
|
||||
$base := $is1080 ? 0.08 : 0.04;
|
||||
$perSec := $is1080 ? 0.02 : 0.01;
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
: $contains($m, "pro")
|
||||
? (
|
||||
$base := $is1080 ? 0.275 : 0.075;
|
||||
$perSec := $is1080 ? 0.075 : 0.05;
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
: $contains($m, "turbo")
|
||||
? (
|
||||
$is1080
|
||||
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
|
||||
: (
|
||||
$d <= 2 ? {"type":"usd","usd": 0.05}
|
||||
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
|
||||
)
|
||||
)
|
||||
: {"type":"usd","usd": 0.04}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -244,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -363,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -520,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]),
|
||||
expr="""
|
||||
(
|
||||
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
|
||||
$resKey := $substringBefore(widgets.size, ":");
|
||||
$pps := $lookup($ppsTable, $resKey);
|
||||
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -681,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
|
||||
$pps := $lookup($ppsTable, widgets.resolution);
|
||||
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -828,6 +855,22 @@ class WanReferenceVideoApi(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10;
|
||||
$inputMin := 2 * $rate;
|
||||
$inputMax := 5 * $rate;
|
||||
$outputPrice := widgets.duration * $rate;
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $inputMin + $outputPrice,
|
||||
"max_usd": $inputMax + $outputPrice
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
178
comfy_api_nodes/nodes_wavespeed.py
Normal file
178
comfy_api_nodes/nodes_wavespeed.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.wavespeed import (
|
||||
FlashVSRRequest,
|
||||
TaskCreatedResponse,
|
||||
TaskResultResponse,
|
||||
SeedVR2ImageRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
get_number_of_images,
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
|
||||
|
||||
class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedFlashVSRNode",
|
||||
display_name="FlashVSR Video Upscale",
|
||||
category="api node/video/WaveSpeed",
|
||||
description="Fast, high-quality video upscaler that "
|
||||
"boosts resolution and restores clarity for low-resolution or blurry footage.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
|
||||
{
|
||||
"type":"usd",
|
||||
"usd": $lookup($price_for_1sec, widgets.target_resolution),
|
||||
"format":{"suffix": "/second", "approximate": true}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
target_resolution: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_container_format_is_mp4(video)
|
||||
validate_video_duration(video, min_duration=5, max_duration=60 * 10)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
|
||||
response_model=TaskCreatedResponse,
|
||||
data=FlashVSRRequest(
|
||||
target_resolution=target_resolution.lower(),
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
duration=video.get_duration(),
|
||||
),
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
f"Task processing failed with code={final_response.code} and message={final_response.message}"
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
|
||||
|
||||
|
||||
class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedImageUpscaleNode",
|
||||
display_name="WaveSpeed Image Upscale",
|
||||
category="api node/image/WaveSpeed",
|
||||
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
|
||||
{"type":"usd", "usd": $lookup($prices, widgets.model)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
target_resolution: str,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
if model == "SeedVR2":
|
||||
model_path = "seedvr2/image"
|
||||
else:
|
||||
model_path = "ultimate-image-upscaler"
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
|
||||
response_model=TaskCreatedResponse,
|
||||
data=SeedVR2ImageRequest(
|
||||
target_resolution=target_resolution.lower(),
|
||||
image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
|
||||
),
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
f"Task processing failed with code={final_response.code} and message={final_response.message}"
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
|
||||
|
||||
|
||||
class WavespeedExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
WavespeedFlashVSRNode,
|
||||
WavespeedImageUpscaleNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> WavespeedExtension:
|
||||
return WavespeedExtension()
|
||||
@@ -1,10 +0,0 @@
|
||||
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||
# This is used for development purposes to generate stubs for unreleased API endpoints.
|
||||
apis:
|
||||
filter:
|
||||
root: openapi.yaml
|
||||
decorators:
|
||||
filter-in:
|
||||
property: tags
|
||||
value: ['API Nodes']
|
||||
matchStrategy: all
|
||||
@@ -1,10 +0,0 @@
|
||||
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||
|
||||
apis:
|
||||
filter:
|
||||
root: openapi.yaml
|
||||
decorators:
|
||||
filter-in:
|
||||
property: tags
|
||||
value: ['API Nodes', 'Released']
|
||||
matchStrategy: all
|
||||
@@ -11,6 +11,7 @@ from .conversions import (
|
||||
audio_input_to_mp3,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
convert_mask_to_image,
|
||||
downscale_image_tensor,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
@@ -72,6 +73,7 @@ __all__ = [
|
||||
"audio_input_to_mp3",
|
||||
"audio_to_base64_string",
|
||||
"bytesio_to_image_tensor",
|
||||
"convert_mask_to_image",
|
||||
"downscale_image_tensor",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
|
||||
@@ -451,6 +451,12 @@ def resize_mask_to_image(
|
||||
return mask
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: Input.Image) -> torch.Tensor:
|
||||
"""Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image."""
|
||||
mask = mask.unsqueeze(-1)
|
||||
return torch.cat([mask] * 3, dim=-1)
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
|
||||
@@ -43,7 +43,7 @@ class UploadResponse(BaseModel):
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
image: torch.Tensor | list[torch.Tensor],
|
||||
*,
|
||||
max_images: int = 8,
|
||||
mime_type: str | None = None,
|
||||
@@ -55,15 +55,28 @@ async def upload_images_to_comfyapi(
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
"""
|
||||
tensors: list[torch.Tensor] = []
|
||||
if isinstance(image, list):
|
||||
for img in image:
|
||||
is_batch = len(img.shape) > 3
|
||||
if is_batch:
|
||||
tensors.extend(img[i] for i in range(img.shape[0]))
|
||||
else:
|
||||
tensors.append(img)
|
||||
else:
|
||||
is_batch = len(image.shape) > 3
|
||||
if is_batch:
|
||||
tensors.extend(image[i] for i in range(image.shape[0]))
|
||||
else:
|
||||
tensors.append(image)
|
||||
|
||||
# if batched, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
num_to_upload = min(batch_len, max_images)
|
||||
num_to_upload = min(len(tensors), max_images)
|
||||
batch_start_ts = time.monotonic()
|
||||
|
||||
for idx in range(num_to_upload):
|
||||
tensor = image[idx] if is_batch else image
|
||||
tensor = tensors[idx]
|
||||
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
|
||||
|
||||
effective_label = wait_label
|
||||
|
||||
@@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
do_easycache = easycache.should_do_easycache(sigmas)
|
||||
if do_easycache:
|
||||
easycache.check_metadata(x)
|
||||
# if there isn't a cache diff for current conds, we cannot skip this step
|
||||
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
|
||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||
if easycache.skip_current_step:
|
||||
if easycache.skip_current_step and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
@@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
@@ -240,6 +242,9 @@ class EasyCacheHolder:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
|
||||
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
|
||||
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
self.total_steps_skipped += 1
|
||||
|
||||
@@ -7,6 +7,7 @@ import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@@ -244,6 +245,10 @@ class ModelPatchLoader:
|
||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||
sd = z_image_convert(sd)
|
||||
config = {}
|
||||
if 'control_layers.4.adaLN_modulation.0.weight' not in sd:
|
||||
config['n_control_layers'] = 3
|
||||
config['additional_in_dim'] = 17
|
||||
config['refiner_control'] = True
|
||||
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
|
||||
config['n_control_layers'] = 15
|
||||
config['additional_in_dim'] = 17
|
||||
@@ -253,6 +258,14 @@ class ModelPatchLoader:
|
||||
if torch.count_nonzero(ref_weight) == 0:
|
||||
config['broken'] = True
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||
elif "audio_proj.proj1.weight" in sd:
|
||||
model = MultiTalkModelPatch(
|
||||
audio_window=5, context_tokens=32, vae_scale=4,
|
||||
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
|
||||
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
@@ -520,6 +533,38 @@ class USOStyleReference:
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
class MultiTalkModelPatch(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
audio_window: int = 5,
|
||||
intermediate_dim: int = 512,
|
||||
in_dim: int = 5120,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
vae_scale: int = 4,
|
||||
num_layers: int = 40,
|
||||
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.audio_proj = MultiTalkAudioProjModel(
|
||||
seq_len=audio_window,
|
||||
seq_len_vf=audio_window+vae_scale-1,
|
||||
intermediate_dim=intermediate_dim,
|
||||
out_dim=out_dim,
|
||||
context_tokens=context_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
|
||||
@@ -254,6 +254,7 @@ class ResizeType(str, Enum):
|
||||
SCALE_HEIGHT = "scale height"
|
||||
SCALE_TOTAL_PIXELS = "scale total pixels"
|
||||
MATCH_SIZE = "match size"
|
||||
SCALE_TO_MULTIPLE = "scale to multiple"
|
||||
|
||||
def is_image(input: torch.Tensor) -> bool:
|
||||
# images have 4 dimensions: [batch, height, width, channels]
|
||||
@@ -328,7 +329,7 @@ def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method
|
||||
if height < width:
|
||||
width = round((width / height) * shorter_size)
|
||||
height = shorter_size
|
||||
elif width > height:
|
||||
elif width < height:
|
||||
height = round((height / width) * shorter_size)
|
||||
width = shorter_size
|
||||
else:
|
||||
@@ -363,6 +364,43 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
|
||||
if multiple <= 1:
|
||||
return input
|
||||
is_type_image = is_image(input)
|
||||
if is_type_image:
|
||||
_, height, width, _ = input.shape
|
||||
else:
|
||||
_, height, width = input.shape
|
||||
target_w = (width // multiple) * multiple
|
||||
target_h = (height // multiple) * multiple
|
||||
if target_w == 0 or target_h == 0:
|
||||
return input
|
||||
if target_w == width and target_h == height:
|
||||
return input
|
||||
s_w = target_w / width
|
||||
s_h = target_h / height
|
||||
if s_w >= s_h:
|
||||
scaled_w = target_w
|
||||
scaled_h = int(math.ceil(height * s_w))
|
||||
if scaled_h < target_h:
|
||||
scaled_h = target_h
|
||||
else:
|
||||
scaled_h = target_h
|
||||
scaled_w = int(math.ceil(width * s_h))
|
||||
if scaled_w < target_w:
|
||||
scaled_w = target_w
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
x0 = (scaled_w - target_w) // 2
|
||||
y0 = (scaled_h - target_h) // 2
|
||||
x1 = x0 + target_w
|
||||
y1 = y0 + target_h
|
||||
if is_type_image:
|
||||
return input[:, y0:y1, x0:x1, :]
|
||||
return input[:, y0:y1, x0:x1]
|
||||
|
||||
class ResizeImageMaskNode(io.ComfyNode):
|
||||
|
||||
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
@@ -378,6 +416,7 @@ class ResizeImageMaskNode(io.ComfyNode):
|
||||
longer_size: int
|
||||
shorter_size: int
|
||||
megapixels: float
|
||||
multiple: int
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -417,6 +456,9 @@ class ResizeImageMaskNode(io.ComfyNode):
|
||||
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||
crop_combo,
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
|
||||
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
]),
|
||||
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||
],
|
||||
@@ -442,6 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode):
|
||||
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
|
||||
elif selected_type == ResizeType.MATCH_SIZE:
|
||||
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||
elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
|
||||
return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method))
|
||||
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||
|
||||
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
@@ -506,6 +550,7 @@ class BatchImagesNode(io.ComfyNode):
|
||||
node_id="BatchImagesNode",
|
||||
display_name="Batch Images",
|
||||
category="image",
|
||||
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
||||
inputs=[
|
||||
io.Autogrow.Input("images", template=autogrow_template)
|
||||
],
|
||||
@@ -592,6 +637,97 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
|
||||
from comfy_api.latest import node_replace
|
||||
|
||||
def register_replacements():
|
||||
register_replacements_longeredge()
|
||||
register_replacements_batchimages()
|
||||
register_replacements_upscaleimage()
|
||||
register_replacements_controlnet()
|
||||
register_replacements_load3d()
|
||||
register_replacements_preview3d()
|
||||
register_replacements_svdimg2vid()
|
||||
register_replacements_conditioningavg()
|
||||
|
||||
def register_replacements_longeredge():
|
||||
# No dynamic inputs here
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ImageScaleToMaxDimension",
|
||||
old_node_id="ResizeImagesByLongerEdge",
|
||||
old_widget_ids=["longer_edge"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="image", assign=node_replace.InputMap.OldId("images")),
|
||||
node_replace.InputMap(new_id="largest_size", assign=node_replace.InputMap.OldId("longer_edge")),
|
||||
node_replace.InputMap(new_id="upscale_method", assign=node_replace.InputMap.SetValue("lanczos")),
|
||||
],
|
||||
# just to test the frontend output_mapping code, does nothing really here
|
||||
output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
|
||||
))
|
||||
|
||||
def register_replacements_batchimages():
|
||||
# BatchImages node uses Autogrow
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="BatchImagesNode",
|
||||
old_node_id="ImageBatch",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="images.image0", assign=node_replace.InputMap.OldId("image1")),
|
||||
node_replace.InputMap(new_id="images.image1", assign=node_replace.InputMap.OldId("image2")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_upscaleimage():
|
||||
# ResizeImageMaskNode uses DynamicCombo
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ResizeImageMaskNode",
|
||||
old_node_id="ImageScaleBy",
|
||||
old_widget_ids=["upscale_method", "scale_by"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="input", assign=node_replace.InputMap.OldId("image")),
|
||||
node_replace.InputMap(new_id="resize_type", assign=node_replace.InputMap.SetValue("scale by multiplier")),
|
||||
node_replace.InputMap(new_id="resize_type.multiplier", assign=node_replace.InputMap.OldId("scale_by")),
|
||||
node_replace.InputMap(new_id="scale_method", assign=node_replace.InputMap.OldId("upscale_method")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_controlnet():
|
||||
# T2IAdapterLoader → ControlNetLoader
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ControlNetLoader",
|
||||
old_node_id="T2IAdapterLoader",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="control_net_name", assign=node_replace.InputMap.OldId("t2i_adapter_name")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_load3d():
|
||||
# Load3DAnimation merged into Load3D
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="Load3D",
|
||||
old_node_id="Load3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_preview3d():
|
||||
# Preview3DAnimation merged into Preview3D
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="Preview3D",
|
||||
old_node_id="Preview3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_svdimg2vid():
|
||||
# Typo fix: SDV → SVD
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="SVD_img2vid_Conditioning",
|
||||
old_node_id="SDV_img2vid_Conditioning",
|
||||
))
|
||||
|
||||
def register_replacements_conditioningavg():
|
||||
# Typo fix: trailing space in node name
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ConditioningAverage",
|
||||
old_node_id="ConditioningAverage ",
|
||||
))
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
|
||||
@@ -16,6 +16,7 @@ class PreviewAny():
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "utils"
|
||||
SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]
|
||||
|
||||
def main(self, source=None):
|
||||
value = 'None'
|
||||
|
||||
@@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode):
|
||||
node_id="StringConcatenate",
|
||||
display_name="Concatenate",
|
||||
category="utils/string",
|
||||
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
|
||||
@@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
node_id="ImageUpscaleWithModel",
|
||||
display_name="Upscale Image (using Model)",
|
||||
category="image/upscaling",
|
||||
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
|
||||
inputs=[
|
||||
io.UpscaleModel.Input("upscale_model"),
|
||||
io.Image.Input("image"),
|
||||
|
||||
@@ -8,9 +8,10 @@ import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
from typing import Tuple, TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||
class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
class DCValues(TypedDict):
|
||||
mode: str
|
||||
audio_encoder_output_2: io.AudioEncoderOutput.Type
|
||||
mask: io.Mask.Type
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanInfiniteTalkToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.DynamicCombo.Input("mode", options=[
|
||||
io.DynamicCombo.Option("single_speaker", []),
|
||||
io.DynamicCombo.Option("two_speakers", [
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
|
||||
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
|
||||
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
|
||||
]),
|
||||
]),
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
|
||||
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
|
||||
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
io.Image.Input("previous_frames", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model"),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
io.Int.Output(display_name="trim_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
|
||||
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
|
||||
|
||||
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
|
||||
raise ValueError("Not enough previous frames provided.")
|
||||
|
||||
if mode["mode"] == "two_speakers":
|
||||
audio_encoder_output_2 = mode["audio_encoder_output_2"]
|
||||
mask_1 = mode["mask_1"]
|
||||
mask_2 = mode["mask_2"]
|
||||
|
||||
if audio_encoder_output_2 is not None:
|
||||
if mask_1 is None or mask_2 is None:
|
||||
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
|
||||
|
||||
ref_masks = None
|
||||
if mask_1 is not None and mask_2 is not None:
|
||||
if audio_encoder_output_2 is None:
|
||||
raise ValueError("Second audio encoder output must be provided if two masks are used.")
|
||||
ref_masks = torch.cat([mask_1, mask_2])
|
||||
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
image[:start_image.shape[0]] = start_image
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
model_patched = model.clone()
|
||||
|
||||
encoded_audio_list = []
|
||||
seq_lengths = []
|
||||
|
||||
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
|
||||
if audio_encoder_output is None:
|
||||
continue
|
||||
all_layers = audio_encoder_output["encoded_audio_all_layers"]
|
||||
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
|
||||
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
|
||||
encoded_audio_list.append(encoded_audio)
|
||||
seq_lengths.append(encoded_audio.shape[0])
|
||||
|
||||
# Pad / combine depending on multi_audio_type
|
||||
multi_audio_type = "add"
|
||||
if len(encoded_audio_list) > 1:
|
||||
if multi_audio_type == "para":
|
||||
max_len = max(seq_lengths)
|
||||
padded = []
|
||||
for emb in encoded_audio_list:
|
||||
if emb.shape[0] < max_len:
|
||||
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
|
||||
emb = torch.cat([emb, pad], dim=0)
|
||||
padded.append(emb)
|
||||
encoded_audio_list = padded
|
||||
elif multi_audio_type == "add":
|
||||
total_len = sum(seq_lengths)
|
||||
full_list = []
|
||||
offset = 0
|
||||
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
|
||||
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
|
||||
full[offset:offset+seq_len] = emb
|
||||
full_list.append(full)
|
||||
offset += seq_len
|
||||
encoded_audio_list = full_list
|
||||
|
||||
token_ref_target_masks = None
|
||||
if ref_masks is not None:
|
||||
token_ref_target_masks = torch.nn.functional.interpolate(
|
||||
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
|
||||
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
|
||||
|
||||
# when extending from previous frames
|
||||
if previous_frames is not None:
|
||||
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
frame_offset = previous_frames.shape[0] - motion_frame_count
|
||||
|
||||
audio_start = frame_offset
|
||||
audio_end = audio_start + length
|
||||
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
|
||||
|
||||
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
|
||||
trim_image = motion_frame_count
|
||||
else:
|
||||
audio_start = trim_image = 0
|
||||
audio_end = length
|
||||
motion_frames_latent = concat_latent_image[:, :, :1]
|
||||
|
||||
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
|
||||
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
|
||||
|
||||
# add outer sample wrapper
|
||||
model_patched.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
|
||||
"infinite_talk_outer_sample",
|
||||
InfiniteTalkOuterSampleWrapper(
|
||||
motion_frames_latent,
|
||||
model_patch,
|
||||
is_extend=previous_frames is not None,
|
||||
))
|
||||
# add cross-attention patch
|
||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
|
||||
if token_ref_target_masks is not None:
|
||||
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension):
|
||||
WanHuMoImageToVideo,
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
88
comfy_extras/nodes_zimage.py
Normal file
88
comfy_extras/nodes_zimage.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import node_helpers
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class TextEncodeZImageOmni(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeZImageOmni",
|
||||
category="advanced/conditioning",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.ClipVision.Input("image_encoder", optional=True),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Boolean.Input("auto_resize_images", default=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Image.Input("image1", optional=True),
|
||||
io.Image.Input("image2", optional=True),
|
||||
io.Image.Input("image3", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
|
||||
ref_latents = []
|
||||
images = list(filter(lambda a: a is not None, [image1, image2, image3]))
|
||||
|
||||
prompt_list = []
|
||||
template = None
|
||||
if len(images) > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
|
||||
|
||||
encoded_images = []
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if image_encoder is not None:
|
||||
encoded_images.append(image_encoder.encode_image(image))
|
||||
|
||||
if vae is not None:
|
||||
if auto_resize_images:
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(1024 * 1024)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by / 8.0) * 8
|
||||
height = round(samples.shape[2] * scale_by / 8.0) * 8
|
||||
|
||||
image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
|
||||
ref_latents.append(vae.encode(image))
|
||||
|
||||
tokens = clip.tokenize(prompt, llama_template=template)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
|
||||
extra_text_embeds = []
|
||||
for p in prompt_list:
|
||||
tokens = clip.tokenize(p, llama_template="{}")
|
||||
text_embeds = clip.encode_from_tokens_scheduled(tokens)
|
||||
extra_text_embeds.append(text_embeds[0][0])
|
||||
|
||||
if len(ref_latents) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
|
||||
if len(encoded_images) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
|
||||
if len(extra_text_embeds) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)
|
||||
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class ZImageExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeZImageOmni,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ZImageExtension:
|
||||
return ZImageExtension()
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.9.1"
|
||||
__version__ = "0.10.0"
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
default_preview_method = args.preview_method
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||
|
||||
def preview_to_image(latent_image, do_scale=True):
|
||||
if do_scale:
|
||||
|
||||
55
nodes.py
55
nodes.py
@@ -5,6 +5,7 @@ import torch
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import hashlib
|
||||
import inspect
|
||||
import traceback
|
||||
@@ -69,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"]
|
||||
|
||||
def encode(self, clip, text):
|
||||
if clip is None:
|
||||
@@ -85,6 +87,7 @@ class ConditioningCombine:
|
||||
FUNCTION = "combine"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"]
|
||||
|
||||
def combine(self, conditioning_1, conditioning_2):
|
||||
return (conditioning_1 + conditioning_2, )
|
||||
@@ -293,6 +296,7 @@ class VAEDecode:
|
||||
|
||||
CATEGORY = "latent"
|
||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||
SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"]
|
||||
|
||||
def decode(self, vae, samples):
|
||||
latent = samples["samples"]
|
||||
@@ -345,6 +349,7 @@ class VAEEncode:
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "latent"
|
||||
SEARCH_ALIASES = ["encode", "encode image", "image to latent"]
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
t = vae.encode(pixels)
|
||||
@@ -580,6 +585,7 @@ class CheckpointLoaderSimple:
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"]
|
||||
|
||||
def load_checkpoint(self, ckpt_name):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
@@ -666,6 +672,7 @@ class LoraLoader:
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
|
||||
SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"]
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
@@ -700,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader):
|
||||
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||
|
||||
class VAELoader:
|
||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
||||
@staticmethod
|
||||
def vae_list(s):
|
||||
@@ -788,6 +795,7 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
metadata = None
|
||||
if vae_name == "pixel_space":
|
||||
sd = {}
|
||||
sd["pixel_space_vae"] = torch.tensor(1.0)
|
||||
@@ -798,8 +806,8 @@ class VAELoader:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
return (vae,)
|
||||
|
||||
@@ -812,6 +820,7 @@ class ControlNetLoader:
|
||||
FUNCTION = "load_controlnet"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"]
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||
@@ -888,6 +897,7 @@ class ControlNetApplyAdvanced:
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"]
|
||||
|
||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||
if strength == 0:
|
||||
@@ -1198,6 +1208,7 @@ class EmptyLatentImage:
|
||||
|
||||
CATEGORY = "latent"
|
||||
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
|
||||
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
@@ -1538,6 +1549,7 @@ class KSampler:
|
||||
|
||||
CATEGORY = "sampling"
|
||||
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
|
||||
SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"]
|
||||
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
|
||||
@@ -1602,6 +1614,7 @@ class SaveImage:
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"]
|
||||
|
||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
@@ -1638,6 +1651,8 @@ class PreviewImage(SaveImage):
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
self.compress_level = 1
|
||||
|
||||
SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
@@ -1656,6 +1671,7 @@ class LoadImage:
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"]
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
@@ -1808,6 +1824,7 @@ class ImageScale:
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"]
|
||||
|
||||
def upscale(self, image, upscale_method, width, height, crop):
|
||||
if width == 0 and height == 0:
|
||||
@@ -2371,6 +2388,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2383,37 +2401,12 @@ async def init_builtin_extra_nodes():
|
||||
|
||||
async def init_builtin_api_nodes():
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_ideogram.py",
|
||||
"nodes_openai.py",
|
||||
"nodes_minimax.py",
|
||||
"nodes_veo2.py",
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_bytedance.py",
|
||||
"nodes_ltxv.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_sora.py",
|
||||
"nodes_topaz.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
"nodes_vidu.py",
|
||||
"nodes_wan.py",
|
||||
]
|
||||
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
return api_nodes_files
|
||||
api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py")))
|
||||
|
||||
import_failed = []
|
||||
for node_file in api_nodes_files:
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
import_failed.append(node_file)
|
||||
if not await load_custom_node(node_file, module_parent="comfy_api_nodes"):
|
||||
import_failed.append(os.path.basename(node_file))
|
||||
|
||||
return import_failed
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.9.1"
|
||||
version = "0.10.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.4
|
||||
comfyui-frontend-package==1.37.11
|
||||
comfyui-workflow-templates==0.8.15
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
@@ -21,7 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.6
|
||||
comfy-kitchen>=0.2.7
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
10
server.py
10
server.py
@@ -40,6 +40,7 @@ from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@@ -204,6 +205,7 @@ class PromptServer():
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@@ -682,11 +684,16 @@ class PromptServer():
|
||||
|
||||
if hasattr(obj_class, 'API_NODE'):
|
||||
info['api_node'] = obj_class.API_NODE
|
||||
|
||||
info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', [])
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
seed_assets(["models"])
|
||||
try:
|
||||
seed_assets(["models"])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to seed assets: {e}")
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
@@ -987,6 +994,7 @@ class PromptServer():
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.node_replace_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
|
||||
|
||||
def test_model_field_to_float_input():
|
||||
"""Tests mapping a float field with constraints."""
|
||||
|
||||
class ModelWithFloatField(BaseModel):
|
||||
cfg_scale: Optional[float] = Field(
|
||||
default=0.5,
|
||||
description="Flexibility in video generation",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
multiple_of=0.001,
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"tooltip": "Flexibility in video generation",
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.001,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.FLOAT, ModelWithFloatField, "cfg_scale"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_float_input_no_constraints():
|
||||
"""Tests mapping a float field with no constraints."""
|
||||
|
||||
class ModelWithFloatField(BaseModel):
|
||||
cfg_scale: Optional[float] = Field(default=0.5)
|
||||
|
||||
expected_output = (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.FLOAT, ModelWithFloatField, "cfg_scale"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_int_input():
|
||||
"""Tests mapping an int field with constraints."""
|
||||
|
||||
class ModelWithIntField(BaseModel):
|
||||
num_frames: Optional[int] = Field(
|
||||
default=10,
|
||||
description="Number of frames to generate",
|
||||
ge=1,
|
||||
le=100,
|
||||
multiple_of=1,
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 10,
|
||||
"tooltip": "Number of frames to generate",
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"step": 1,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_string_input():
|
||||
"""Tests mapping a string field."""
|
||||
|
||||
class ModelWithStringField(BaseModel):
|
||||
prompt: Optional[str] = Field(
|
||||
default="A beautiful sunset over a calm ocean",
|
||||
description="A prompt for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "A beautiful sunset over a calm ocean",
|
||||
"tooltip": "A prompt for the video generation",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_string_input_multiline():
|
||||
"""Tests mapping a string field."""
|
||||
|
||||
class ModelWithStringField(BaseModel):
|
||||
prompt: Optional[str] = Field(
|
||||
default="A beautiful sunset over a calm ocean",
|
||||
description="A prompt for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "A beautiful sunset over a calm ocean",
|
||||
"tooltip": "A prompt for the video generation",
|
||||
"multiline": True,
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithStringField, "prompt", multiline=True
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_combo_input():
|
||||
"""Tests mapping a combo field."""
|
||||
|
||||
class MockEnum(str, Enum):
|
||||
option_1 = "option 1"
|
||||
option_2 = "option 2"
|
||||
option_3 = "option 3"
|
||||
|
||||
class ModelWithComboField(BaseModel):
|
||||
model_name: Optional[MockEnum] = Field("option 1", description="Model Name")
|
||||
|
||||
expected_output = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["option 1", "option 2", "option 3"],
|
||||
"default": "option 1",
|
||||
"tooltip": "Model Name",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_combo_input_no_options():
|
||||
"""Tests mapping a combo field with no options."""
|
||||
|
||||
class ModelWithComboField(BaseModel):
|
||||
model_name: Optional[str] = Field(description="Model Name")
|
||||
|
||||
expected_output = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Model Name",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.COMBO, ModelWithComboField, "model_name"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_image_input():
|
||||
"""Tests mapping an image field."""
|
||||
|
||||
class ModelWithImageField(BaseModel):
|
||||
image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="An image for the video generation",
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "An image for the video generation",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image")
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_description():
|
||||
"""Tests mapping a field with no description."""
|
||||
|
||||
class ModelWithNoDescriptionField(BaseModel):
|
||||
field: Optional[str] = Field(default="default value")
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "default value",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoDescriptionField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_default():
|
||||
"""Tests mapping a field with no default."""
|
||||
|
||||
class ModelWithNoDefaultField(BaseModel):
|
||||
field: Optional[str] = Field(description="A field with no default")
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"tooltip": "A field with no default",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoDefaultField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_no_metadata():
|
||||
"""Tests mapping a field with no metadata or properties defined on the schema."""
|
||||
|
||||
class ModelWithNoMetadataField(BaseModel):
|
||||
field: Optional[str] = Field()
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoMetadataField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
|
||||
|
||||
def test_model_field_to_node_input_default_is_none():
|
||||
"""
|
||||
Tests mapping a field with a default of `None`.
|
||||
I.e., the default field should be included as the schema explicitly sets it to `None`.
|
||||
"""
|
||||
|
||||
class ModelWithNoneDefaultField(BaseModel):
|
||||
field: Optional[str] = Field(
|
||||
default=None, description="A field with a default of None"
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "A field with a default of None",
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = model_field_to_node_input(
|
||||
IO.STRING, ModelWithNoneDefaultField, "field"
|
||||
)
|
||||
|
||||
assert actual_output[0] == expected_output[0]
|
||||
assert actual_output[1] == expected_output[1]
|
||||
Reference in New Issue
Block a user