Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-18 14:10:07 +00:00)

Compare commits: automation...painter-no (1 commit, SHA 2c5d49efd1)

.gitignore (vendored): 2 lines changed
@@ -11,7 +11,7 @@ extra_model_paths.yaml
 /.vs
 .vscode/
 .idea/
-venv*/
+venv/
 .venv/
 /web/extensions/*
 !/web/extensions/logging.js.example
@@ -1,105 +0,0 @@
-from __future__ import annotations
-
-from aiohttp import web
-
-from typing import TYPE_CHECKING, TypedDict
-if TYPE_CHECKING:
-    from comfy_api.latest._io_public import NodeReplace
-
-from comfy_execution.graph_utils import is_link
-import nodes
-
-class NodeStruct(TypedDict):
-    inputs: dict[str, str | int | float | bool | tuple[str, int]]
-    class_type: str
-    _meta: dict[str, str]
-
-def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
-    new_node_struct = node_struct.copy()
-    if empty_inputs:
-        new_node_struct["inputs"] = {}
-    else:
-        new_node_struct["inputs"] = node_struct["inputs"].copy()
-    new_node_struct["_meta"] = node_struct["_meta"].copy()
-    return new_node_struct
-
-
-class NodeReplaceManager:
-    """Manages node replacement registrations."""
-
-    def __init__(self):
-        self._replacements: dict[str, list[NodeReplace]] = {}
-
-    def register(self, node_replace: NodeReplace):
-        """Register a node replacement mapping."""
-        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
-
-    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
-        """Get replacements for an old node ID."""
-        return self._replacements.get(old_node_id)
-
-    def has_replacement(self, old_node_id: str) -> bool:
-        """Check if a replacement exists for an old node ID."""
-        return old_node_id in self._replacements
-
-    def apply_replacements(self, prompt: dict[str, NodeStruct]):
-        connections: dict[str, list[tuple[str, str, int]]] = {}
-        need_replacement: set[str] = set()
-        for node_number, node_struct in prompt.items():
-            class_type = node_struct["class_type"]
-            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
-            if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
-                need_replacement.add(node_number)
-            # keep track of connections
-            for input_id, input_value in node_struct["inputs"].items():
-                if is_link(input_value):
-                    conn_number = input_value[0]
-                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
-        for node_number in need_replacement:
-            node_struct = prompt[node_number]
-            class_type = node_struct["class_type"]
-            replacements = self.get_replacement(class_type)
-            if replacements is None:
-                continue
-            # just use the first replacement
-            replacement = replacements[0]
-            new_node_id = replacement.new_node_id
-            # if replacement is not a valid node, skip trying to replace it as will only cause confusion
-            if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
-                continue
-            # first, replace node id (class_type)
-            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
-            new_node_struct["class_type"] = new_node_id
-            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
-            # second, replace inputs
-            if replacement.input_mapping is not None:
-                for input_map in replacement.input_mapping:
-                    if "set_value" in input_map:
-                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
-                    elif "old_id" in input_map:
-                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
-            # finalize input replacement
-            prompt[node_number] = new_node_struct
-            # third, replace outputs
-            if replacement.output_mapping is not None:
-                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
-                if node_number in connections:
-                    for conns in connections[node_number]:
-                        conn_node_number, conn_input_id, old_output_idx = conns
-                        for output_map in replacement.output_mapping:
-                            if output_map["old_idx"] == old_output_idx:
-                                new_output_idx = output_map["new_idx"]
-                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
-                                previous_input[1] = new_output_idx
-
-    def as_dict(self):
-        """Serialize all replacements to dict."""
-        return {
-            k: [v.as_dict() for v in v_list]
-            for k, v_list in self._replacements.items()
-        }
-
-    def add_routes(self, routes):
-        @routes.get("/node_replacements")
-        async def get_node_replacements(request):
-            return web.json_response(self.as_dict())
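For orientation, here is how the input mapping applied by apply_replacements() above transforms a single prompt entry. This is a standalone sketch, not part of the diff; the node names and input IDs are invented:

    # Standalone sketch of the input_mapping semantics in apply_replacements().
    # "OldBlur"/"NewBlur" and the field names are hypothetical examples.
    old_node = {
        "inputs": {"image": ["4", 0], "radius": 2.0},
        "class_type": "OldBlur",
        "_meta": {"title": "Blur"},
    }
    input_mapping = [
        {"new_id": "source", "old_id": "image"},  # carry an old input over under a new ID
        {"new_id": "sigma", "set_value": 1.5},    # inject a fixed value
    ]
    new_node = {"inputs": {}, "class_type": "NewBlur", "_meta": dict(old_node["_meta"])}
    for m in input_mapping:
        if "set_value" in m:
            new_node["inputs"][m["new_id"]] = m["set_value"]
        elif "old_id" in m:
            new_node["inputs"][m["new_id"]] = old_node["inputs"][m["old_id"]]
    print(new_node["inputs"])  # {'source': ['4', 0], 'sigma': 1.5}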
comfy/checkpoint_pickle.py (new file): 13 lines added
@@ -0,0 +1,13 @@
+import pickle
+
+load = pickle.load
+
+class Empty:
+    pass
+
+class Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        #TODO: safe unpickle
+        if module.startswith("pytorch_lightning"):
+            return Empty
+        return super().find_class(module, name)
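The effect of the Unpickler above: when an old pytorch_lightning checkpoint is unpickled, any global reference into pytorch_lightning resolves to the inert Empty placeholder instead of importing the real module. A minimal demonstration with a hand-built pickle stream (not from the diff; the module need not be installed):

    import io
    import pickle

    class Empty:
        pass

    class Unpickler(pickle.Unpickler):
        def find_class(self, module, name):
            if module.startswith("pytorch_lightning"):
                return Empty
            return super().find_class(module, name)

    # Protocol-0 stream: GLOBAL "pytorch_lightning.callbacks ModelCheckpoint", then STOP.
    data = b"cpytorch_lightning.callbacks\nModelCheckpoint\n."
    obj = Unpickler(io.BytesIO(data)).load()
    print(obj)  # <class '__main__.Empty'>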
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
         if source_attention_mask.ndim == 2:
             source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

-        x = self.in_proj(self.embed(target_input_ids))
         context = source_hidden_states
+        x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
         position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
         position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
         position_embeddings = self.rotary_emb(x, position_ids)
@@ -152,7 +152,6 @@ class Chroma(nn.Module):
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
-        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})

         # running on sequences img
@@ -229,7 +228,6 @@ class Chroma(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if i not in self.skip_dit:
@@ -196,9 +196,6 @@ class DoubleStreamBlock(nn.Module):
         else:
             (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

-        transformer_patches = transformer_options.get("patches", {})
-        extra_options = transformer_options.copy()
-
         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -227,12 +224,6 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v

-        if "attn1_output_patch" in transformer_patches:
-            extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
-            patch = transformer_patches["attn1_output_patch"]
-            for p in patch:
-                attn = p(attn, extra_options)
-
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
@@ -312,9 +303,6 @@ class SingleStreamBlock(nn.Module):
         else:
             mod = vec

-        transformer_patches = transformer_options.get("patches", {})
-        extra_options = transformer_options.copy()
-
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -324,12 +312,6 @@ class SingleStreamBlock(nn.Module):
         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
-
-        if "attn1_output_patch" in transformer_patches:
-            patch = transformer_patches["attn1_output_patch"]
-            for p in patch:
-                attn = p(attn, extra_options)
-
         # compute activation in mlp stream, cat again and run second linear layer
         if self.yak_mlp:
             mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
@@ -142,7 +142,6 @@ class Flux(nn.Module):
         attn_mask: Tensor = None,
     ) -> Tensor:

         transformer_options = transformer_options.copy()
-        patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -232,7 +231,6 @@ class Flux(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
@@ -304,7 +304,6 @@ class HunyuanVideo(nn.Module):
         control=None,
         transformer_options={},
     ) -> Tensor:
-        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})

         initial_shape = list(img.shape)
@@ -417,7 +416,6 @@ class HunyuanVideo(nn.Module):

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
-        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
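All three models above keep the same per-block bookkeeping: each block is told its position so that patches can act on specific blocks. A standalone sketch of that pattern (the patch itself is hypothetical):

    # Per-block bookkeeping as in the loops above; logging_patch is invented.
    transformer_options = {"total_blocks": 3, "block_type": "single"}

    def logging_patch(x, opts):
        # Fires only on the last single block.
        if opts["block_index"] == opts["total_blocks"] - 1:
            return x * 2
        return x

    x = 1
    for i in range(transformer_options["total_blocks"]):
        transformer_options["block_index"] = i
        x = logging_patch(x + 1, transformer_options)
    print(x)  # 8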
@@ -102,7 +102,19 @@ class VideoConv3d(nn.Module):
         return self.conv(x)

 def interpolate_up(x, scale_factor):
-    return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
+    try:
+        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
+    except: #operation not implemented for bf16
+        orig_shape = list(x.shape)
+        out_shape = orig_shape[:2]
+        for i in range(len(orig_shape) - 2):
+            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
+        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
+        split = 8
+        l = out.shape[1] // split
+        for i in range(0, out.shape[1], l):
+            out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
+        return out

 class Upsample(nn.Module):
     def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
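The fallback added here works around nearest-neighbor interpolate not being implemented for some dtypes by upsampling in float32 one channel slice at a time. A standalone check that the chunked path matches the one-shot result (shapes are arbitrary):

    import torch

    x = torch.rand(1, 16, 4, 8, 8)
    scale = (2.0, 2.0, 2.0)
    full = torch.nn.functional.interpolate(x, scale_factor=scale, mode="nearest")
    out = torch.empty(1, 16, 8, 16, 16)
    step = x.shape[1] // 8  # split=8 as in the diff
    for i in range(0, x.shape[1], step):
        out[:, i:i+step] = torch.nn.functional.interpolate(
            x[:, i:i+step].to(torch.float32), scale_factor=scale, mode="nearest"
        )
    print(torch.equal(out, full))  # True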
@@ -374,31 +374,6 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten

     return padded_tensor

-def calculate_shape(patches, weight, key, original_weights=None):
-    current_shape = weight.shape
-
-    for p in patches:
-        v = p[1]
-        offset = p[3]
-
-        # Offsets restore the old shape; lists force a diff without metadata
-        if offset is not None or isinstance(v, list):
-            continue
-
-        if isinstance(v, weight_adapter.WeightAdapterBase):
-            adapter_shape = v.calculate_shape(key)
-            if adapter_shape is not None:
-                current_shape = adapter_shape
-            continue
-
-        # Standard diff logic with padding
-        if len(v) == 2:
-            patch_type, patch_data = v[0], v[1]
-            if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
-                current_shape = patch_data[0].shape
-
-    return current_shape
-
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
     for p in patches:
         strength = p[0]
@@ -178,7 +178,10 @@ class BaseModel(torch.nn.Module):
             xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

         context = c_crossattn
-        dtype = self.get_dtype_inference()
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype

         xc = xc.to(dtype)
         device = xc.device
@@ -215,13 +218,6 @@ class BaseModel(torch.nn.Module):
     def get_dtype(self):
         return self.diffusion_model.dtype

-    def get_dtype_inference(self):
-        dtype = self.get_dtype()
-
-        if self.manual_cast_dtype is not None:
-            dtype = self.manual_cast_dtype
-        return dtype
-
     def encode_adm(self, **kwargs):
         return None

@@ -376,7 +372,9 @@ class BaseModel(torch.nn.Module):
             input_shapes += shape

         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype_inference()
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
             area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1167,7 +1165,7 @@ class Anima(BaseModel):
             t5xxl_ids = t5xxl_ids.unsqueeze(0)

             if torch.is_inference_mode_enabled(): # if not we are training
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
             else:
                 out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
                 out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
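The pattern being reverted in these hunks is simple enough to show in isolation: get_dtype_inference() was just get_dtype() plus the manual_cast_dtype override, and the revert inlines that at each call site. A minimal sketch (attribute names mirror the diff, the rest is scaffolding):

    import torch

    class Model:
        def __init__(self, dtype, manual_cast_dtype=None):
            self._dtype = dtype
            self.manual_cast_dtype = manual_cast_dtype

        def get_dtype(self):
            return self._dtype

        # Helper removed by this commit; both forms compute the same thing.
        def get_dtype_inference(self):
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            return dtype

    m = Model(torch.float16, manual_cast_dtype=torch.bfloat16)
    dtype = m.get_dtype()              # inlined form used after the revert
    if m.manual_cast_dtype is not None:
        dtype = m.manual_cast_dtype
    assert dtype == m.get_dtype_inference()  # bfloat16 either way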
@@ -406,16 +406,13 @@ class ModelPatcher:
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

-    def disable_model_cfg1_optimization(self):
-        self.model_options["disable_cfg1_optimization"] = True
-
     def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
         if disable_cfg1_optimization:
-            self.disable_model_cfg1_optimization()
+            self.model_options["disable_cfg1_optimization"] = True

     def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
         self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@@ -1517,10 +1514,8 @@ class ModelPatcherDynamic(ModelPatcher):

             weight, _, _ = get_key_weight(self.model, key)
             if weight is None:
-                return (False, 0)
+                return 0
             if key in self.patches:
-                if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
-                    return (True, 0)
                 setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
                 num_patches += 1
             else:
@@ -1534,13 +1529,7 @@ class ModelPatcherDynamic(ModelPatcher):
             model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
             weight._model_dtype = model_dtype
             geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
-            return (False, comfy.memory_management.vram_aligned_size(geometry))
-
-        def force_load_param(self, param_key, device_to):
-            key = key_param_name_to_key(n, param_key)
-            if key in self.backup:
-                comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
-            self.patch_weight_to_device(key, device_to=device_to)
+            return comfy.memory_management.vram_aligned_size(geometry)

         if hasattr(m, "comfy_cast_weights"):
             m.comfy_cast_weights = True
@@ -1548,19 +1537,13 @@ class ModelPatcherDynamic(ModelPatcher):
                 m.seed_key = n
                 set_dirty(m, dirty)

-                force_load, v_weight_size = setup_param(self, m, n, "weight")
-                force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
-                force_load = force_load or force_load_bias
-                v_weight_size += v_weight_bias
+                v_weight_size = 0
+                v_weight_size += setup_param(self, m, n, "weight")
+                v_weight_size += setup_param(self, m, n, "bias")

-                if force_load:
-                    logging.info(f"Module {n} has resizing Lora - force loading")
-                    force_load_param(self, "weight", device_to)
-                    force_load_param(self, "bias", device_to)
-                else:
-                    if vbar is not None and not hasattr(m, "_v"):
-                        m._v = vbar.alloc(v_weight_size)
-                        allocated_size += v_weight_size
+                if vbar is not None and not hasattr(m, "_v"):
+                    m._v = vbar.alloc(v_weight_size)
+                    allocated_size += v_weight_size

             else:
                 for param in params:
@@ -1623,11 +1606,6 @@ class ModelPatcherDynamic(ModelPatcher):
         for m in self.model.modules():
             move_weight_functions(m, device_to)

-        keys = list(self.backup.keys())
-        for k in keys:
-            bk = self.backup[k]
-            comfy.utils.set_attr_param(self.model, k, bk.weight)
-
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
         with self.use_ejected(skip_and_inject_on_exit_only=True):
comfy/ops.py: 23 lines changed
@@ -21,6 +21,7 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
 import comfy.float
+import comfy.rmsnorm
 import json
 import comfy.memory_management
 import comfy.pinned_memory
@@ -79,7 +80,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None

@@ -170,10 +171,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             #FIXME: this is not accurate, we need to be sensitive to the compute dtype
             x = lowvram_fn(x)
         if (isinstance(orig, QuantizedTensor) and
-                (want_requant and len(fns) == 0 or update_weight)):
+                (orig.dtype == dtype and len(fns) == 0 or update_weight)):
             seed = comfy.utils.string_to_seed(s.seed_key)
             y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
-            if want_requant and len(fns) == 0:
+            if orig.dtype == dtype and len(fns) == 0:
                 #The layer actually wants our freshly saved QT
                 x = y
             elif update_weight:
@@ -194,7 +195,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     return weight, bias, (offload_stream, device if signature is not None else None, None)


-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
    # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
    # will add async-offload support to your cast and improve performance.
@@ -212,7 +213,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    if hasattr(s, "_v"):
-        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
+        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)

    if offloadable and (device != s.weight.device or
                        (s.bias is not None and device != s.bias.device)):
@@ -462,7 +463,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

-    class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
+    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
         def reset_parameters(self):
             self.bias = None
             return None
@@ -474,7 +475,8 @@ class disable_weight_init:
                 weight = None
                 bias = None
                 offload_stream = None
-            x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x

@@ -850,8 +852,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
        def _forward(self, input, weight, bias):
            return torch.nn.functional.linear(input, weight, bias)

-        def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
-            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
+        def forward_comfy_cast_weights(self, input, compute_dtype=None):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
            x = self._forward(input, weight, bias)
            uncast_bias_weight(self, weight, bias, offload_stream)
            return x
@@ -881,7 +883,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
            scale = comfy.model_management.cast_to_device(scale, input.device, None)
            input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

-            output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
+
+            output = self.forward_comfy_cast_weights(input, compute_dtype)

            # Reshape output back to 3D if input was 3D
            if reshaped_3d:
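The NOTE in cast_bias_weight() asks custom-node authors to opt into the offloadable path. A sketch of that usage pattern, written against the API exactly as it appears in this diff (not verified against a specific ComfyUI release):

    import torch
    import comfy.ops

    class MyLinear(comfy.ops.disable_weight_init.Linear):
        def forward_comfy_cast_weights(self, input):
            # offloadable=True returns (weight, bias, offload_stream); release
            # the pair with uncast_bias_weight() after the last use.
            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
            x = torch.nn.functional.linear(input, weight, bias)
            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
            return x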
@@ -1,10 +1,57 @@
 import torch
 import comfy.model_management
+import numbers
+import logging
+
+RMSNorm = None
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+    RMSNorm = torch.nn.RMSNorm
+except:
+    rms_norm_torch = None
+    logging.warning("Please update pytorch to use native RMSNorm")

-RMSNorm = torch.nn.RMSNorm

 def rms_norm(x, weight=None, eps=1e-6):
-    if weight is None:
-        return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
+    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+        if weight is None:
+            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+        else:
+            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
     else:
-        return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        if weight is None:
+            return r
+        else:
+            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+
+
+if RMSNorm is None:
+    class RMSNorm(torch.nn.Module):
+        def __init__(
+            self,
+            normalized_shape,
+            eps=1e-6,
+            elementwise_affine=True,
+            device=None,
+            dtype=None,
+        ):
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super().__init__()
+            if isinstance(normalized_shape, numbers.Integral):
+                # mypy error: incompatible types in assignment
+                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            self.eps = eps
+            self.elementwise_affine = elementwise_affine
+            if self.elementwise_affine:
+                self.weight = torch.nn.Parameter(
+                    torch.empty(self.normalized_shape, **factory_kwargs)
+                )
+            else:
+                self.register_parameter("weight", None)
+            self.bias = None
+
+        def forward(self, x):
+            return rms_norm(x, self.weight, self.eps)
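For reference, the manual fallback restored above computes the same normalization as torch's native op. A quick standalone check (requires torch 2.4+ for torch.nn.functional.rms_norm; numbers are arbitrary):

    import torch

    x = torch.randn(2, 8)
    eps = 1e-6
    manual = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    native = torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
    print(torch.allclose(manual, native, atol=1e-6))  # True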
@@ -20,7 +20,7 @@
 import torch
 import math
 import struct
-import comfy.memory_management
+import comfy.checkpoint_pickle
 import safetensors.torch
 import numpy as np
 from PIL import Image
@@ -38,26 +38,26 @@ import warnings
 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap


-if True: # ckpt/pt file whitelist for safe loading of old sd files
+ALWAYS_SAFE_LOAD = False
+if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
     class ModelCheckpoint:
         pass
     ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

     def scalar(*args, **kwargs):
-        return None
+        from numpy.core.multiarray import scalar as sc
+        return sc(*args, **kwargs)
     scalar.__module__ = "numpy.core.multiarray"

     from numpy import dtype
     from numpy.dtypes import Float64DType

-    def encode(*args, **kwargs): # no longer necessary on newer torch
-        return None
-    encode.__module__ = "_codecs"
+    from _codecs import encode

     torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
+    ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")

 else:
     logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

 # Current as of safetensors 0.7.0
 _TYPES = {
@@ -140,8 +140,11 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         if MMAP_TORCH_FILES:
             torch_args["mmap"] = True

-        pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
-
+        if safe_load or ALWAYS_SAFE_LOAD:
+            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
+        else:
+            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
+            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
     if "state_dict" in pl_sd:
         sd = pl_sd["state_dict"]
     else:
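The allowlist mechanism used above (torch >= 2.4): torch.load(..., weights_only=True) refuses to resolve unknown pickle globals unless they are registered, and matching is done by the registered object's __module__ and name, which is why the diff reassigns __module__ on the stand-in class. A short sketch:

    import torch

    class ModelCheckpoint:
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    torch.serialization.add_safe_globals([ModelCheckpoint])
    # torch.load(path, weights_only=True) will now tolerate checkpoints that
    # reference pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.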
@@ -49,12 +49,6 @@ class WeightAdapterBase:
         """
         raise NotImplementedError

-    def calculate_shape(
-        self,
-        key
-    ):
-        return None
-
     def calculate_weight(
         self,
         weight,
@@ -214,13 +214,6 @@ class LoRAAdapter(WeightAdapterBase):
         else:
             return None

-    def calculate_shape(
-        self,
-        key
-    ):
-        reshape = self.weights[5]
-        return tuple(reshape) if reshape is not None else None
-
     def calculate_weight(
         self,
         weight,
@@ -14,7 +14,6 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
-    "node_replacements": True,
 }

@@ -21,17 +21,6 @@ class ComfyAPI_latest(ComfyAPIBase):
     VERSION = "latest"
     STABLE = False

-    def __init__(self):
-        super().__init__()
-        self.node_replacement = self.NodeReplacement()
-        self.execution = self.Execution()
-
-    class NodeReplacement(ProxiedSingleton):
-        async def register(self, node_replace: io.NodeReplace) -> None:
-            """Register a node replacement mapping."""
-            from server import PromptServer
-            PromptServer.instance.node_replace_manager.register(node_replace)
-
     class Execution(ProxiedSingleton):
         async def set_progress(
             self,
@@ -84,6 +73,8 @@ class ComfyAPI_latest(ComfyAPIBase):
                 image=to_display,
             )

+    execution: Execution
+
 class ComfyExtension(ABC):
     async def on_load(self) -> None:
         """
@@ -75,12 +75,6 @@ class NumberDisplay(str, Enum):
     slider = "slider"


-class ControlAfterGenerate(str, Enum):
-    fixed = "fixed"
-    increment = "increment"
-    decrement = "decrement"
-    randomize = "randomize"
-
 class _ComfyType(ABC):
     Type = Any
     io_type: str = None
@@ -269,7 +263,7 @@ class Int(ComfyTypeIO):
     class Input(WidgetInput):
         '''Integer input.'''
         def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
-                     default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
+                     default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
                      display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
             super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
             self.min = min
@@ -351,7 +345,7 @@ class Combo(ComfyTypeIO):
                      tooltip: str=None,
                      lazy: bool=None,
                      default: str | int | Enum = None,
-                     control_after_generate: bool | ControlAfterGenerate=None,
+                     control_after_generate: bool=None,
                      upload: UploadType=None,
                      image_folder: FolderType=None,
                      remote: RemoteOptions=None,
@@ -395,7 +389,7 @@ class MultiCombo(ComfyTypeI):
     Type = list[str]
     class Input(Combo.Input):
         def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
-                     default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
+                     default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
                      socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
             super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
             self.multiselect = True
@@ -2036,74 +2030,11 @@ class _UIOutput(ABC):
         ...


-class InputMapOldId(TypedDict):
-    """Map an old node input to a new node input by ID."""
-    new_id: str
-    old_id: str
-
-class InputMapSetValue(TypedDict):
-    """Set a specific value for a new node input."""
-    new_id: str
-    set_value: Any
-
-InputMap = InputMapOldId | InputMapSetValue
-"""
-Input mapping for node replacement. Type is inferred by dictionary keys:
-- {"new_id": str, "old_id": str} - maps old input to new input
-- {"new_id": str, "set_value": Any} - sets a specific value for new input
-"""
-
-class OutputMap(TypedDict):
-    """Map outputs of node replacement via indexes."""
-    new_idx: int
-    old_idx: int
-
-class NodeReplace:
-    """
-    Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
-
-    Also supports assigning specific values to the input widgets of the new node.
-
-    Args:
-        new_node_id: The class name of the new replacement node.
-        old_node_id: The class name of the deprecated node.
-        old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
-            connected. The workflow JSON stores widget values by their relative position index,
-            not by ID. This list maps those positional indexes to input IDs, enabling the
-            replacement system to correctly identify widget values during node migration.
-        input_mapping: List of input mappings from old node to new node.
-        output_mapping: List of output mappings from old node to new node.
-    """
-    def __init__(self,
-                 new_node_id: str,
-                 old_node_id: str,
-                 old_widget_ids: list[str] | None=None,
-                 input_mapping: list[InputMap] | None=None,
-                 output_mapping: list[OutputMap] | None=None,
-                 ):
-        self.new_node_id = new_node_id
-        self.old_node_id = old_node_id
-        self.old_widget_ids = old_widget_ids
-        self.input_mapping = input_mapping
-        self.output_mapping = output_mapping
-
-    def as_dict(self):
-        """Create serializable representation of the node replacement."""
-        return {
-            "new_node_id": self.new_node_id,
-            "old_node_id": self.old_node_id,
-            "old_widget_ids": self.old_widget_ids,
-            "input_mapping": list(self.input_mapping) if self.input_mapping else None,
-            "output_mapping": list(self.output_mapping) if self.output_mapping else None,
-        }
-
-
 __all__ = [
     "FolderType",
     "UploadType",
     "RemoteOptions",
     "NumberDisplay",
-    "ControlAfterGenerate",

     "comfytype",
     "Custom",
@@ -2190,5 +2121,4 @@ __all__ = [
     "ImageCompare",
     "PriceBadgeDepends",
     "PriceBadge",
-    "NodeReplace",
 ]
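To show how the removed NodeReplace API fits together with the manager deleted earlier, here is a hypothetical registration; the node class names and input IDs are invented for illustration:

    # Hypothetical NodeReplace construction against the API shown above.
    replacement = NodeReplace(
        new_node_id="ImageBlurV2",
        old_node_id="ImageBlur",
        old_widget_ids=["radius"],
        input_mapping=[
            {"new_id": "source", "old_id": "image"},
            {"new_id": "sigma", "set_value": 1.0},
        ],
        output_mapping=[{"new_idx": 0, "old_idx": 0}],
    )
    print(replacement.as_dict()["new_node_id"])  # ImageBlurV2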
@@ -45,55 +45,17 @@ class BriaEditImageRequest(BaseModel):
     )


-class BriaRemoveBackgroundRequest(BaseModel):
-    image: str = Field(...)
-    sync: bool = Field(False)
-    visual_input_content_moderation: bool = Field(
-        False, description="If true, returns 422 on input image moderation failure."
-    )
-    visual_output_content_moderation: bool = Field(
-        False, description="If true, returns 422 on visual output moderation failure."
-    )
-    seed: int = Field(...)
-
-
 class BriaStatusResponse(BaseModel):
     request_id: str = Field(...)
     status_url: str = Field(...)
     warning: str | None = Field(None)


-class BriaRemoveBackgroundResult(BaseModel):
-    image_url: str = Field(...)
-
-
-class BriaRemoveBackgroundResponse(BaseModel):
-    status: str = Field(...)
-    result: BriaRemoveBackgroundResult | None = Field(None)
-
-
-class BriaImageEditResult(BaseModel):
+class BriaResult(BaseModel):
     structured_prompt: str = Field(...)
     image_url: str = Field(...)


-class BriaImageEditResponse(BaseModel):
+class BriaResponse(BaseModel):
     status: str = Field(...)
-    result: BriaImageEditResult | None = Field(None)
-
-
-class BriaRemoveVideoBackgroundRequest(BaseModel):
-    video: str = Field(...)
-    background_color: str = Field(default="transparent", description="Background color for the output video.")
-    output_container_and_codec: str = Field(...)
-    preserve_audio: bool = Field(True)
-    seed: int = Field(...)
-
-
-class BriaRemoveVideoBackgroundResult(BaseModel):
-    video_url: str = Field(...)
-
-
-class BriaRemoveVideoBackgroundResponse(BaseModel):
-    status: str = Field(...)
-    result: BriaRemoveVideoBackgroundResult | None = Field(None)
+    result: BriaResult | None = Field(None)
@@ -64,23 +64,3 @@ class To3DProTaskResultResponse(BaseModel):

 class To3DProTaskQueryRequest(BaseModel):
     JobId: str = Field(...)
-
-
-class To3DUVFileInput(BaseModel):
-    Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
-    Url: str = Field(...)
-
-
-class To3DUVTaskRequest(BaseModel):
-    File: To3DUVFileInput = Field(...)
-
-
-class TextureEditImageInfo(BaseModel):
-    Url: str = Field(...)
-
-
-class TextureEditTaskRequest(BaseModel):
-    File3D: To3DUVFileInput = Field(...)
-    Image: TextureEditImageInfo | None = Field(None)
-    Prompt: str | None = Field(None)
-    EnablePBR: bool | None = Field(None)
@@ -198,6 +198,11 @@ dict_recraft_substyles_v3 = {
 }


+class RecraftModel(str, Enum):
+    recraftv3 = 'recraftv3'
+    recraftv2 = 'recraftv2'
+
+
 class RecraftImageSize(str, Enum):
     res_1024x1024 = '1024x1024'
     res_1365x1024 = '1365x1024'
@@ -216,41 +221,6 @@ class RecraftImageSize(str, Enum):
     res_1707x1024 = '1707x1024'


-RECRAFT_V4_SIZES = [
-    "1024x1024",
-    "1536x768",
-    "768x1536",
-    "1280x832",
-    "832x1280",
-    "1216x896",
-    "896x1216",
-    "1152x896",
-    "896x1152",
-    "832x1344",
-    "1280x896",
-    "896x1280",
-    "1344x768",
-    "768x1344",
-]
-
-RECRAFT_V4_PRO_SIZES = [
-    "2048x2048",
-    "3072x1536",
-    "1536x3072",
-    "2560x1664",
-    "1664x2560",
-    "2432x1792",
-    "1792x2432",
-    "2304x1792",
-    "1792x2304",
-    "1664x2688",
-    "1434x1024",
-    "1024x1434",
-    "2560x1792",
-    "1792x2560",
-]
-
-
 class RecraftColorObject(BaseModel):
     rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')

@@ -264,16 +234,17 @@ class RecraftControlsObject(BaseModel):

 class RecraftImageGenerationRequest(BaseModel):
     prompt: str = Field(..., description='The text prompt describing the image to generate')
-    size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
+    size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
     n: int = Field(..., description='The number of images to generate')
     negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
-    model: str = Field(...)
+    model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
     style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
     substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
     controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
     style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
     strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
+    random_seed: int | None = Field(None, description="Seed for video generation")
     # text_layout


 class RecraftReturnedObject(BaseModel):
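The str-vs-enum change in the request model above is a validation change: with the RecraftImageSize enum restored, pydantic rejects unknown sizes before the request is sent. A hedged standalone sketch (pydantic v2; one enum member shown for brevity):

    from enum import Enum
    from pydantic import BaseModel, ValidationError

    class RecraftImageSize(str, Enum):
        res_1024x1024 = '1024x1024'

    class Req(BaseModel):
        size: RecraftImageSize | None = None

    print(Req(size='1024x1024').size)  # RecraftImageSize.res_1024x1024
    try:
        Req(size='123x456')
    except ValidationError as e:
        print("rejected:", e.error_count(), "error")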
@@ -3,11 +3,7 @@ from typing_extensions import override
 from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis.bria import (
     BriaEditImageRequest,
-    BriaRemoveBackgroundRequest,
-    BriaRemoveBackgroundResponse,
-    BriaRemoveVideoBackgroundRequest,
-    BriaRemoveVideoBackgroundResponse,
-    BriaImageEditResponse,
+    BriaResponse,
     BriaStatusResponse,
     InputModerationSettings,
 )
@@ -15,12 +11,10 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     convert_mask_to_image,
     download_url_to_image_tensor,
     download_url_to_video_output,
     get_number_of_images,
     poll_op,
     sync_op,
-    upload_image_to_comfyapi,
-    upload_video_to_comfyapi,
-    validate_video_duration,
+    upload_images_to_comfyapi,
 )
@@ -79,15 +73,21 @@ class BriaImageEditNode(IO.ComfyNode):
                 IO.DynamicCombo.Input(
                     "moderation",
                     options=[
-                        IO.DynamicCombo.Option("false", []),
                         IO.DynamicCombo.Option(
                             "true",
                             [
-                                IO.Boolean.Input("prompt_content_moderation", default=False),
-                                IO.Boolean.Input("visual_input_moderation", default=False),
-                                IO.Boolean.Input("visual_output_moderation", default=True),
+                                IO.Boolean.Input(
+                                    "prompt_content_moderation", default=False
+                                ),
+                                IO.Boolean.Input(
+                                    "visual_input_moderation", default=False
+                                ),
+                                IO.Boolean.Input(
+                                    "visual_output_moderation", default=True
+                                ),
                             ],
                         ),
+                        IO.DynamicCombo.Option("false", []),
                     ],
                     tooltip="Moderation settings",
                 ),
@@ -127,26 +127,50 @@ class BriaImageEditNode(IO.ComfyNode):
         mask: Input.Image | None = None,
     ) -> IO.NodeOutput:
         if not prompt and not structured_prompt:
-            raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
+            raise ValueError(
+                "One of prompt or structured_prompt is required to be non-empty."
+            )
         if get_number_of_images(image) != 1:
             raise ValueError("Exactly one input image is required.")
         mask_url = None
         if mask is not None:
-            mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
+            mask_url = (
+                await upload_images_to_comfyapi(
+                    cls,
+                    convert_mask_to_image(mask),
+                    max_images=1,
+                    mime_type="image/png",
+                    wait_label="Uploading mask",
+                )
+            )[0]
         response = await sync_op(
             cls,
             ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
             data=BriaEditImageRequest(
                 instruction=prompt if prompt else None,
                 structured_instruction=structured_prompt if structured_prompt else None,
-                images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
+                images=await upload_images_to_comfyapi(
+                    cls,
+                    image,
+                    max_images=1,
+                    mime_type="image/png",
+                    wait_label="Uploading image",
+                ),
                 mask=mask_url,
                 negative_prompt=negative_prompt if negative_prompt else None,
                 guidance_scale=guidance_scale,
                 seed=seed,
                 model_version=model,
                 steps_num=steps,
-                prompt_content_moderation=moderation.get("prompt_content_moderation", False),
-                visual_input_content_moderation=moderation.get("visual_input_moderation", False),
-                visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+                prompt_content_moderation=moderation.get(
+                    "prompt_content_moderation", False
+                ),
+                visual_input_content_moderation=moderation.get(
+                    "visual_input_moderation", False
+                ),
+                visual_output_content_moderation=moderation.get(
+                    "visual_output_moderation", False
+                ),
             ),
             response_model=BriaStatusResponse,
         )
@@ -154,7 +178,7 @@ class BriaImageEditNode(IO.ComfyNode):
             cls,
             ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
             status_extractor=lambda r: r.status,
-            response_model=BriaImageEditResponse,
+            response_model=BriaResponse,
         )
         return IO.NodeOutput(
             await download_url_to_image_tensor(response.result.image_url),
@@ -162,167 +186,11 @@ class BriaImageEditNode(IO.ComfyNode):
         )


-class BriaRemoveImageBackground(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="BriaRemoveImageBackground",
-            display_name="Bria Remove Image Background",
-            category="api node/image/Bria",
-            description="Remove the background from an image using Bria RMBG 2.0.",
-            inputs=[
-                IO.Image.Input("image"),
-                IO.DynamicCombo.Input(
-                    "moderation",
-                    options=[
-                        IO.DynamicCombo.Option("false", []),
-                        IO.DynamicCombo.Option(
-                            "true",
-                            [
-                                IO.Boolean.Input("visual_input_moderation", default=False),
-                                IO.Boolean.Input("visual_output_moderation", default=True),
-                            ],
-                        ),
-                    ],
-                    tooltip="Moderation settings",
-                ),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=2147483647,
-                    display_mode=IO.NumberDisplay.number,
-                    control_after_generate=True,
-                    tooltip="Seed controls whether the node should re-run; "
-                            "results are non-deterministic regardless of seed.",
-                ),
-            ],
-            outputs=[IO.Image.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.018}""",
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        image: Input.Image,
-        moderation: dict,
-        seed: int,
-    ) -> IO.NodeOutput:
-        response = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
-            data=BriaRemoveBackgroundRequest(
-                image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
-                sync=False,
-                visual_input_content_moderation=moderation.get("visual_input_moderation", False),
-                visual_output_content_moderation=moderation.get("visual_output_moderation", False),
-                seed=seed,
-            ),
-            response_model=BriaStatusResponse,
-        )
-        response = await poll_op(
-            cls,
-            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
-            status_extractor=lambda r: r.status,
-            response_model=BriaRemoveBackgroundResponse,
-        )
-        return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
-
-
-class BriaRemoveVideoBackground(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="BriaRemoveVideoBackground",
-            display_name="Bria Remove Video Background",
-            category="api node/video/Bria",
-            description="Remove the background from a video using Bria. ",
-            inputs=[
-                IO.Video.Input("video"),
-                IO.Combo.Input(
-                    "background_color",
-                    options=[
-                        "Black",
-                        "White",
-                        "Gray",
-                        "Red",
-                        "Green",
-                        "Blue",
-                        "Yellow",
-                        "Cyan",
-                        "Magenta",
-                        "Orange",
-                    ],
-                    tooltip="Background color for the output video.",
-                ),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=2147483647,
-                    display_mode=IO.NumberDisplay.number,
-                    control_after_generate=True,
-                    tooltip="Seed controls whether the node should re-run; "
-                            "results are non-deterministic regardless of seed.",
-                ),
-            ],
-            outputs=[IO.Video.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        video: Input.Video,
-        background_color: str,
-        seed: int,
-    ) -> IO.NodeOutput:
-        validate_video_duration(video, max_duration=60.0)
-        response = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
-            data=BriaRemoveVideoBackgroundRequest(
-                video=await upload_video_to_comfyapi(cls, video),
-                background_color=background_color,
-                output_container_and_codec="mp4_h264",
-                seed=seed,
-            ),
-            response_model=BriaStatusResponse,
-        )
-        response = await poll_op(
-            cls,
-            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
-            status_extractor=lambda r: r.status,
-            response_model=BriaRemoveVideoBackgroundResponse,
-        )
-        return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
-
-
 class BriaExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [
             BriaImageEditNode,
-            BriaRemoveImageBackground,
-            BriaRemoveVideoBackground,
         ]
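All of these API nodes share the same submit-then-poll flow (sync_op creates the job, poll_op watches the status URL, then the result URL is downloaded). A reduction of that flow to plain aiohttp, as a sketch only: the endpoint paths mirror the diff, while the base URL, the "IN_PROGRESS" status value, and the payload shape are assumptions:

    import asyncio
    import aiohttp

    async def edit_image(payload: dict, base: str = "https://example.invalid") -> str:
        async with aiohttp.ClientSession(base_url=base) as s:
            # Submit the job; the diff's BriaStatusResponse carries request_id.
            async with s.post("/proxy/bria/v2/image/edit", json=payload) as r:
                request_id = (await r.json())["request_id"]
            # Poll until the job leaves the (assumed) in-progress state.
            while True:
                async with s.get(f"/proxy/bria/v2/status/{request_id}") as r:
                    body = await r.json()
                if body["status"] != "IN_PROGRESS":
                    return body["result"]["image_url"]
                await asyncio.sleep(2)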
@@ -1,48 +1,31 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.hunyuan3d import (
|
||||
Hunyuan3DViewImage,
|
||||
InputGenerateType,
|
||||
ResultFile3D,
|
||||
TextureEditTaskRequest,
|
||||
To3DProTaskCreateResponse,
|
||||
To3DProTaskQueryRequest,
|
||||
To3DProTaskRequest,
|
||||
To3DProTaskResultResponse,
|
||||
To3DUVFileInput,
|
||||
To3DUVTaskRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
return (
|
||||
status == 400
|
||||
and isinstance(body, dict)
|
||||
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == file_type.lower():
|
||||
return i
|
||||
if raise_if_not_found:
|
||||
raise ValueError(f"'{file_type}' file type is not found in the response.")
|
||||
return None
|
||||
|
||||
|
||||
@@ -52,7 +35,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
display_name="Hunyuan3D: Text to Model (Pro)",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@@ -137,7 +120,6 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
@@ -149,14 +131,11 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
)
|
||||
|
||||
|
||||
@@ -166,7 +145,7 @@ class TencentImageToModelNode(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="TencentImageToModelNode",
            display_name="Hunyuan3D: Image(s) to Model",
            display_name="Hunyuan3D: Image(s) to Model (Pro)",
            category="api node/3d/Tencent",
            inputs=[
                IO.Combo.Input(
@@ -289,7 +268,6 @@ class TencentImageToModelNode(IO.ComfyNode):
                EnablePBR=generate_type.get("pbr", None),
                PolygonType=generate_type.get("polygon_type", None),
            ),
            is_rate_limited=_is_tencent_rate_limited,
        )
        if response.Error:
            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -301,257 +279,11 @@ class TencentImageToModelNode(IO.ComfyNode):
            response_model=To3DProTaskResultResponse,
            status_extractor=lambda r: r.Status,
        )
        glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
        obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
        file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
        return IO.NodeOutput(
            f"{task_id}.glb",
            await download_url_to_file_3d(
                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
            ),
            await download_url_to_file_3d(
                get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
            ),
        )


class TencentModelTo3DUVNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TencentModelTo3DUVNode",
            display_name="Hunyuan3D: Model to UV",
            category="api node/3d/Tencent",
            description="Performs UV unfolding on a 3D model to generate a UV texture. "
                        "The input model must have fewer than 30000 faces.",
            inputs=[
                IO.MultiType.Input(
                    "model_3d",
                    types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
                    tooltip="Input 3D model (GLB, OBJ, or FBX).",
                ),
                IO.Int.Input(
                    "seed",
                    default=1,
                    min=0,
                    max=2147483647,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed controls whether the node should re-run; "
                            "results are non-deterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.File3DOBJ.Output(display_name="OBJ"),
                IO.File3DFBX.Output(display_name="FBX"),
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
        )

    SUPPORTED_FORMATS = {"glb", "obj", "fbx"}

    @classmethod
    async def execute(
        cls,
        model_3d: Types.File3D,
        seed: int,
    ) -> IO.NodeOutput:
        _ = seed
        file_format = model_3d.format.lower()
        if file_format not in cls.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported file format: '{file_format}'. "
                f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
            )
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
            response_model=To3DProTaskCreateResponse,
            data=To3DUVTaskRequest(
                File=To3DUVFileInput(
                    Type=file_format.upper(),
                    Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
                )
            ),
            is_rate_limited=_is_tencent_rate_limited,
        )
        if response.Error:
            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
        result = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
            data=To3DProTaskQueryRequest(JobId=response.JobId),
            response_model=To3DProTaskResultResponse,
            status_extractor=lambda r: r.Status,
        )
        return IO.NodeOutput(
            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
            await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
        )


class Tencent3DTextureEditNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Tencent3DTextureEditNode",
            display_name="Hunyuan3D: 3D Texture Edit",
            category="api node/3d/Tencent",
            description="Redraws the textures of an input 3D model according to the prompt.",
            inputs=[
                IO.MultiType.Input(
                    "model_3d",
                    types=[IO.File3DFBX, IO.File3DAny],
                    tooltip="3D model in FBX format. The model should have fewer than 100000 faces.",
                ),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Describes the texture edit. Supports up to 1024 UTF-8 characters.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed controls whether the node should re-run; "
                            "results are non-deterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.File3DGLB.Output(display_name="GLB"),
                IO.File3DFBX.Output(display_name="FBX"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.6}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model_3d: Types.File3D,
        prompt: str,
        seed: int,
    ) -> IO.NodeOutput:
        _ = seed
        file_format = model_3d.format.lower()
        if file_format != "fbx":
            raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
        validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
        model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
            response_model=To3DProTaskCreateResponse,
            data=TextureEditTaskRequest(
                File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
                Prompt=prompt,
                EnablePBR=True,
            ),
            is_rate_limited=_is_tencent_rate_limited,
        )
        if response.Error:
            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")

        result = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
            data=To3DProTaskQueryRequest(JobId=response.JobId),
            response_model=To3DProTaskResultResponse,
            status_extractor=lambda r: r.Status,
        )
        return IO.NodeOutput(
            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
        )


class Tencent3DPartNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Tencent3DPartNode",
            display_name="Hunyuan3D: 3D Part",
            category="api node/3d/Tencent",
            description="Automatically identifies and generates components based on the model structure.",
            inputs=[
                IO.MultiType.Input(
                    "model_3d",
                    types=[IO.File3DFBX, IO.File3DAny],
                    tooltip="3D model in FBX format. The model should have fewer than 30000 faces.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed controls whether the node should re-run; "
                            "results are non-deterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.File3DFBX.Output(display_name="FBX"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
        )

    @classmethod
    async def execute(
        cls,
        model_3d: Types.File3D,
        seed: int,
    ) -> IO.NodeOutput:
        _ = seed
        file_format = model_3d.format.lower()
        if file_format != "fbx":
            raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
        model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
            response_model=To3DProTaskCreateResponse,
            data=To3DUVTaskRequest(
                File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
            ),
            is_rate_limited=_is_tencent_rate_limited,
        )
        if response.Error:
            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
        result = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
            data=To3DProTaskQueryRequest(JobId=response.JobId),
            response_model=To3DProTaskResultResponse,
            status_extractor=lambda r: r.Status,
        )
        return IO.NodeOutput(
            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
        )


@@ -561,9 +293,6 @@ class TencentHunyuan3DExtension(ComfyExtension):
        return [
            TencentTextToModelNode,
            TencentImageToModelNode,
            # TencentModelTo3DUVNode,
            # Tencent3DTextureEditNode,
            Tencent3DPartNode,
        ]
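
Note: the Tencent nodes above all share one request shape: a sync_op call creates a task, a poll_op call polls the matching /query endpoint until the task settles, and the result files are downloaded. A minimal standalone sketch of that flow, with create_task/query_task as hypothetical stand-ins for sync_op/poll_op:

import asyncio

async def create_task(payload):
    # Hypothetical stand-in for sync_op(...): pretend the API accepted the task.
    return "job-123"

async def query_task(job_id):
    # Hypothetical stand-in for poll_op(...): pretend the task finished at once.
    return {"Status": "DONE", "ResultFile3Ds": []}

async def run_to_completion(payload, interval=0.1):
    job_id = await create_task(payload)
    while True:
        result = await query_task(job_id)
        if result["Status"] in ("DONE", "FAIL"):
            return result
        await asyncio.sleep(interval)

print(asyncio.run(run_to_completion({"Prompt": "a wooden chair"})))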


@@ -43,6 +43,7 @@ class SupportedOpenAIModel(str, Enum):
    o1 = "o1"
    o3 = "o3"
    o1_pro = "o1-pro"
    gpt_4o = "gpt-4o"
    gpt_4_1 = "gpt-4.1"
    gpt_4_1_mini = "gpt-4.1-mini"
    gpt_4_1_nano = "gpt-4.1-nano"
@@ -648,6 +649,11 @@ class OpenAIChatNode(IO.ComfyNode):
                "usd": [0.01, 0.04],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
            }
            : $contains($m, "gpt-4o") ? {
                "type": "list_usd",
                "usd": [0.0025, 0.01],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
            }
            : $contains($m, "gpt-4.1-nano") ? {
                "type": "list_usd",
                "usd": [0.0001, 0.0004],

@@ -1,4 +1,5 @@
from io import BytesIO
from typing import Optional, Union

import aiohttp
import torch
@@ -8,8 +9,6 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft import (
    RECRAFT_V4_PRO_SIZES,
    RECRAFT_V4_SIZES,
    RecraftColor,
    RecraftColorChain,
    RecraftControls,
@@ -19,6 +18,7 @@ from comfy_api_nodes.apis.recraft import (
    RecraftImageGenerationResponse,
    RecraftImageSize,
    RecraftIO,
    RecraftModel,
    RecraftStyle,
    RecraftStyleV3,
    get_v3_substyles,
@@ -39,7 +39,7 @@ async def handle_recraft_file_request(
    cls: type[IO.ComfyNode],
    image: torch.Tensor,
    path: str,
    mask: torch.Tensor | None = None,
    mask: Optional[torch.Tensor] = None,
    total_pixels: int = 4096 * 4096,
    timeout: int = 1024,
    request=None,
@@ -73,11 +73,11 @@ async def handle_recraft_file_request(
def recraft_multipart_parser(
    data,
    parent_key=None,
    formatter: type[callable] | None = None,
    converted_to_check: list[list] | None = None,
    formatter: Optional[type[callable]] = None,
    converted_to_check: Optional[list[list]] = None,
    is_list: bool = False,
    return_mode: str = "formdata",  # "dict" | "formdata"
) -> dict | aiohttp.FormData:
) -> Union[dict, aiohttp.FormData]:
    """
    Formats data so that multipart/form-data works with the aiohttp library when both files and data are present.

@@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
            node_id="RecraftStyleV3InfiniteStyleLibrary",
            display_name="Recraft Style - Infinite Style Library",
            category="api node/image/Recraft",
            description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
            description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
            inputs=[
                IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
            ],
@@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
            data=RecraftImageGenerationRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                model="recraftv3",
                model=RecraftModel.recraftv3,
                size=size,
                n=n,
                style=recraft_style.style,
@@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt,
            model="recraftv3",
            model=RecraftModel.recraftv3,
            n=n,
            strength=round(strength, 2),
            style=recraft_style.style,
@@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt,
            model="recraftv3",
            model=RecraftModel.recraftv3,
            n=n,
            style=recraft_style.style,
            substyle=recraft_style.substyle,
@@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
            data=RecraftImageGenerationRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                model="recraftv3",
                model=RecraftModel.recraftv3,
                size=size,
                n=n,
                style=recraft_style.style,
@@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt,
            model="recraftv3",
            model=RecraftModel.recraftv3,
            n=n,
            style=recraft_style.style,
            substyle=recraft_style.substyle,
@@ -1078,252 +1078,6 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
        )


class RecraftV4TextToImageNode(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="RecraftV4TextToImageNode",
            display_name="Recraft V4 Text to Image",
            category="api node/image/Recraft",
            description="Generates images using Recraft V4 or V4 Pro models.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Prompt for the image generation. Maximum 10,000 characters.",
                ),
                IO.String.Input(
                    "negative_prompt",
                    multiline=True,
                    tooltip="An optional text description of undesired elements on an image.",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "recraftv4",
                            [
                                IO.Combo.Input(
                                    "size",
                                    options=RECRAFT_V4_SIZES,
                                    default="1024x1024",
                                    tooltip="The size of the generated image.",
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option(
                            "recraftv4_pro",
                            [
                                IO.Combo.Input(
                                    "size",
                                    options=RECRAFT_V4_PRO_SIZES,
                                    default="2048x2048",
                                    tooltip="The size of the generated image.",
                                ),
                            ],
                        ),
                    ],
                    tooltip="The model to use for generation.",
                ),
                IO.Int.Input(
                    "n",
                    default=1,
                    min=1,
                    max=6,
                    tooltip="The number of images to generate.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="Seed to determine if the node should re-run; "
                            "actual results are nondeterministic regardless of seed.",
                ),
                IO.Custom(RecraftIO.CONTROLS).Input(
                    "recraft_controls",
                    tooltip="Optional additional controls over the generation via the Recraft Controls node.",
                    optional=True,
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
                expr="""
                (
                    $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
                    {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        negative_prompt: str,
        model: dict,
        n: int,
        seed: int,
        recraft_controls: RecraftControls | None = None,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
            response_model=RecraftImageGenerationResponse,
            data=RecraftImageGenerationRequest(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                model=model["model"],
                size=model["size"],
                n=n,
                controls=recraft_controls.create_api_model() if recraft_controls else None,
            ),
            max_retries=1,
        )
        images = []
        for data in response.data:
            with handle_recraft_image_output():
                image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
                if len(image.shape) < 4:
                    image = image.unsqueeze(0)
                images.append(image)
        return IO.NodeOutput(torch.cat(images, dim=0))
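
Note: the V4 price badge above is a straight table lookup multiplied by the image count. A minimal sketch using the same numbers (illustrative only):

PRICES = {"recraftv4": 0.04, "recraftv4_pro": 0.25}

def badge_usd(model, n):
    # Mirrors $lookup($prices, widgets.model) * widgets.n in the JSONata expr.
    return PRICES[model] * n

print(badge_usd("recraftv4_pro", 2))  # 0.5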


class RecraftV4TextToVectorNode(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="RecraftV4TextToVectorNode",
            display_name="Recraft V4 Text to Vector",
            category="api node/image/Recraft",
            description="Generates SVG using Recraft V4 or V4 Pro models.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Prompt for the image generation. Maximum 10,000 characters.",
                ),
                IO.String.Input(
                    "negative_prompt",
                    multiline=True,
                    tooltip="An optional text description of undesired elements on an image.",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "recraftv4",
                            [
                                IO.Combo.Input(
                                    "size",
                                    options=RECRAFT_V4_SIZES,
                                    default="1024x1024",
                                    tooltip="The size of the generated image.",
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option(
                            "recraftv4_pro",
                            [
                                IO.Combo.Input(
                                    "size",
                                    options=RECRAFT_V4_PRO_SIZES,
                                    default="2048x2048",
                                    tooltip="The size of the generated image.",
                                ),
                            ],
                        ),
                    ],
                    tooltip="The model to use for generation.",
                ),
                IO.Int.Input(
                    "n",
                    default=1,
                    min=1,
                    max=6,
                    tooltip="The number of images to generate.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="Seed to determine if the node should re-run; "
                            "actual results are nondeterministic regardless of seed.",
                ),
                IO.Custom(RecraftIO.CONTROLS).Input(
                    "recraft_controls",
                    tooltip="Optional additional controls over the generation via the Recraft Controls node.",
                    optional=True,
                ),
            ],
            outputs=[
                IO.SVG.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
                expr="""
                (
                    $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
                    {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        negative_prompt: str,
        model: dict,
        n: int,
        seed: int,
        recraft_controls: RecraftControls | None = None,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
            response_model=RecraftImageGenerationResponse,
            data=RecraftImageGenerationRequest(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                model=model["model"],
                size=model["size"],
                n=n,
                style="vector_illustration",
                substyle=None,
                controls=recraft_controls.create_api_model() if recraft_controls else None,
            ),
            max_retries=1,
        )
        svg_data = []
        for data in response.data:
            svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
        return IO.NodeOutput(SVG(svg_data))


class RecraftExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1344,8 +1098,6 @@ class RecraftExtension(ComfyExtension):
            RecraftCreateStyleNode,
            RecraftColorRGBNode,
            RecraftControlsNode,
            RecraftV4TextToImageNode,
            RecraftV4TextToVectorNode,
        ]


@@ -54,7 +54,6 @@ async def execute_task(
        response_model=TaskStatusResponse,
        status_extractor=lambda r: r.state,
        progress_extractor=lambda r: r.progress,
        price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
        max_poll_attempts=max_poll_attempts,
    )
    if not response.creations:
@@ -1307,36 +1306,6 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
                            ),
                        ],
                    ),
                    IO.DynamicCombo.Option(
                        "viduq3-turbo",
                        [
                            IO.Combo.Input(
                                "aspect_ratio",
                                options=["16:9", "9:16", "3:4", "4:3", "1:1"],
                                tooltip="The aspect ratio of the output video.",
                            ),
                            IO.Combo.Input(
                                "resolution",
                                options=["720p", "1080p"],
                                tooltip="Resolution of the output video.",
                            ),
                            IO.Int.Input(
                                "duration",
                                default=5,
                                min=1,
                                max=16,
                                step=1,
                                display_mode=IO.NumberDisplay.slider,
                                tooltip="Duration of the output video in seconds.",
                            ),
                            IO.Boolean.Input(
                                "audio",
                                default=False,
                                tooltip="When enabled, outputs video with sound "
                                        "(including dialogue and sound effects).",
                            ),
                        ],
                    ),
                ],
                tooltip="Model to use for video generation.",
            ),
@@ -1365,20 +1334,13 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
                depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
                expr="""
                (
                    $res := $lookup(widgets, "model.resolution");
                    $d := $lookup(widgets, "model.duration");
                    $contains(widgets.model, "turbo")
                    ? (
                        $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                    : (
                        $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                    $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
                    $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
                    {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
                )
                """,
            ),
@@ -1447,31 +1409,6 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
                            ),
                        ],
                    ),
                    IO.DynamicCombo.Option(
                        "viduq3-turbo",
                        [
                            IO.Combo.Input(
                                "resolution",
                                options=["720p", "1080p"],
                                tooltip="Resolution of the output video.",
                            ),
                            IO.Int.Input(
                                "duration",
                                default=5,
                                min=1,
                                max=16,
                                step=1,
                                display_mode=IO.NumberDisplay.slider,
                                tooltip="Duration of the output video in seconds.",
                            ),
                            IO.Boolean.Input(
                                "audio",
                                default=False,
                                tooltip="When enabled, outputs video with sound "
                                        "(including dialogue and sound effects).",
                            ),
                        ],
                    ),
                ],
                tooltip="Model to use for video generation.",
            ),
@@ -1505,20 +1442,13 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
                depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
                expr="""
                (
                    $res := $lookup(widgets, "model.resolution");
                    $d := $lookup(widgets, "model.duration");
                    $contains(widgets.model, "turbo")
                    ? (
                        $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                    : (
                        $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                    $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
                    $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
                    {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
                )
                """,
            ),
@@ -1551,145 +1481,6 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
        return IO.NodeOutput(await download_url_to_video_output(results[0].url))


class Vidu3StartEndToVideoNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Vidu3StartEndToVideoNode",
            display_name="Vidu Q3 Start/End Frame-to-Video Generation",
            category="api node/video/Vidu",
            description="Generate a video from a start frame, an end frame, and a prompt.",
            inputs=[
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "viduq3-pro",
                            [
                                IO.Combo.Input(
                                    "resolution",
                                    options=["720p", "1080p"],
                                    tooltip="Resolution of the output video.",
                                ),
                                IO.Int.Input(
                                    "duration",
                                    default=5,
                                    min=1,
                                    max=16,
                                    step=1,
                                    display_mode=IO.NumberDisplay.slider,
                                    tooltip="Duration of the output video in seconds.",
                                ),
                                IO.Boolean.Input(
                                    "audio",
                                    default=False,
                                    tooltip="When enabled, outputs video with sound "
                                            "(including dialogue and sound effects).",
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option(
                            "viduq3-turbo",
                            [
                                IO.Combo.Input(
                                    "resolution",
                                    options=["720p", "1080p"],
                                    tooltip="Resolution of the output video.",
                                ),
                                IO.Int.Input(
                                    "duration",
                                    default=5,
                                    min=1,
                                    max=16,
                                    step=1,
                                    display_mode=IO.NumberDisplay.slider,
                                    tooltip="Duration of the output video in seconds.",
                                ),
                                IO.Boolean.Input(
                                    "audio",
                                    default=False,
                                    tooltip="When enabled, outputs video with sound "
                                            "(including dialogue and sound effects).",
                                ),
                            ],
                        ),
                    ],
                    tooltip="Model to use for video generation.",
                ),
                IO.Image.Input("first_frame"),
                IO.Image.Input("end_frame"),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Prompt description (max 2000 characters).",
                ),
                IO.Int.Input(
                    "seed",
                    default=1,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
                expr="""
                (
                    $res := $lookup(widgets, "model.resolution");
                    $d := $lookup(widgets, "model.duration");
                    $contains(widgets.model, "turbo")
                    ? (
                        $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                    : (
                        $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
                        {"type":"usd","usd": $rate * $d}
                    )
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: dict,
        first_frame: Input.Image,
        end_frame: Input.Image,
        prompt: str,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, max_length=2000)
        validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
        payload = TaskCreationRequest(
            model=model["model"],
            prompt=prompt,
            duration=model["duration"],
            seed=seed,
            resolution=model["resolution"],
            audio=model["audio"],
            images=[
                (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
                for frame in (first_frame, end_frame)
            ],
        )
        results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
        return IO.NodeOutput(await download_url_to_video_output(results[0].url))


class ViduExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1706,7 +1497,6 @@ class ViduExtension(ComfyExtension):
            ViduMultiFrameVideoNode,
            Vidu3TextToVideoNode,
            Vidu3ImageToVideoNode,
            Vidu3StartEndToVideoNode,
        ]
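
Note: the old and new Vidu price expressions above differ in shape: the removed form charges a per-second rate times duration, while the added form charges a resolution-dependent base price plus a per-second rate for every second after the first. A worked comparison at 1080p for 5 seconds, using the text-to-video numbers visible in the hunks (illustrative only):

def rate_times_duration_usd(resolution, duration):
    # Shape of the removed expression (turbo tier shown): flat rate times duration.
    rate = {"720p": 0.06, "1080p": 0.08}[resolution]
    return rate * duration

def base_plus_per_second_usd(resolution, duration):
    # Shape of the added expression: base price plus a rate per extra second.
    base = {"720p": 0.075, "1080p": 0.1}[resolution]
    per_sec = {"720p": 0.025, "1080p": 0.05}[resolution]
    return base + per_sec * (duration - 1)

print(rate_times_duration_usd("1080p", 5))   # 0.4
print(base_plus_per_second_usd("1080p", 5))  # 0.3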


@@ -33,7 +33,6 @@ from .download_helpers import (
    download_url_to_video_output,
)
from .upload_helpers import (
    upload_3d_model_to_comfyapi,
    upload_audio_to_comfyapi,
    upload_file_to_comfyapi,
    upload_image_to_comfyapi,
@@ -63,7 +62,6 @@ __all__ = [
    "sync_op",
    "sync_op_raw",
    # Upload helpers
    "upload_3d_model_to_comfyapi",
    "upload_audio_to_comfyapi",
    "upload_file_to_comfyapi",
    "upload_image_to_comfyapi",

@@ -57,7 +57,7 @@ def tensor_to_bytesio(
    image: torch.Tensor,
    *,
    total_pixels: int | None = 2048 * 2048,
    mime_type: str | None = "image/png",
    mime_type: str = "image/png",
) -> BytesIO:
    """Converts a torch.Tensor image to a named BytesIO object.


@@ -164,27 +164,6 @@ async def upload_video_to_comfyapi(
    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)


_3D_MIME_TYPES = {
    "glb": "model/gltf-binary",
    "obj": "model/obj",
    "fbx": "application/octet-stream",
}


async def upload_3d_model_to_comfyapi(
    cls: type[IO.ComfyNode],
    model_3d: Types.File3D,
    file_format: str,
) -> str:
    """Uploads a 3D model file to the ComfyUI API and returns its download URL."""
    return await upload_file_to_comfyapi(
        cls,
        model_3d.get_data(),
        f"{uuid.uuid4()}.{file_format}",
        _3D_MIME_TYPES.get(file_format, "application/octet-stream"),
    )


async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,

@@ -7,7 +7,6 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from tqdm.auto import trange

CLAMP_QUANTILE = 0.99

@@ -50,22 +49,12 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
              "full_diff": LORAType.FULL_DIFF}

def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
    comfy.model_management.load_models_gpu([model_diff])
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)

    sd_keys = list(sd.keys())
    for index in trange(len(sd_keys), unit="weight"):
        k = sd_keys[index]
        op_keys = sd_keys[index].rsplit('.', 1)
        if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
            continue
        op = comfy.utils.get_attr(model_diff.model, op_keys[0])
        if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
            weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
        else:
    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]

        if op_keys[1] == "weight":
            if lora_type == LORAType.STANDARD:
                if weight_diff.ndim < 2:
                    if bias_diff:
@@ -80,8 +69,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
        elif lora_type == LORAType.FULL_DIFF:
            output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()

        elif bias_diff and op_keys[1] == "bias":
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
    return output_sd
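
Note: for context on the LORAType.STANDARD branch referenced above, standard LoRA extraction factors each 2-D weight delta into a low-rank pair via truncated SVD and clamps outliers, which is what the CLAMP_QUANTILE constant is for. A minimal sketch of that decomposition (an illustration of the general technique, not the exact code in this file):

import torch

CLAMP_QUANTILE = 0.99

def extract_lora_pair(weight_diff, rank):
    # Truncated SVD: keep the top-`rank` singular triplets of the weight delta.
    u, s, vh = torch.linalg.svd(weight_diff.float(), full_matrices=False)
    up = u[:, :rank] * s[:rank]   # fold singular values into the "up" factor
    down = vh[:rank, :]
    # Clamp outliers to the 0.99 quantile of the absolute values.
    limit = torch.quantile(torch.cat([up.flatten(), down.flatten()]).abs(), CLAMP_QUANTILE)
    return up.clamp(-limit, limit), down.clamp(-limit, limit)

up, down = extract_lora_pair(torch.randn(64, 32), rank=8)
print(up.shape, down.shape)  # torch.Size([64, 8]) torch.Size([8, 32])

up @ down then approximates the original delta at rank 8.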

class LoraSave(io.ComfyNode):

@@ -1,99 +0,0 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override


class NAGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
            ],
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        m = model.clone()

        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out

            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)

            if img_slice is not None:
                orig_out = out
                out = out[:, img_slice[0]:img_slice[1]]  # only apply on img part

            batch_size = out.shape[0]
            half_size = batch_size // len(cond_or_uncond)

            ind_neg = cond_or_uncond.index(1)
            ind_pos = cond_or_uncond.index(0)
            z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
            z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]

            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)

            ratio = norm_guided / norm_pos
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio

            guided_normalized = guided * scale_factor

            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is not None:
                orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
                orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
                return orig_out
            else:
                out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
                return out

        m.set_model_attn1_output_patch(nag_attention_output_patch)
        m.disable_model_cfg1_optimization()

        return io.NodeOutput(m)


class NagExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            NAGuidance,
        ]


async def comfy_entrypoint() -> NagExtension:
    return NagExtension()
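
Note: the guidance math of the deleted patch above, condensed into a standalone function for reference (same formula: extrapolate past the negative branch, cap the L1-norm growth at a factor of nag_tau, then alpha-blend with the positive branch):

import torch

def nag(z_pos, z_neg, nag_scale=5.0, nag_alpha=0.5, nag_tau=1.5, eps=1e-6):
    # Extrapolate away from the negative attention output.
    guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
    # Rescale so the L1 norm grows by at most a factor of nag_tau.
    norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
    norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
    ratio = norm_guided / norm_pos
    guided = guided * (torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio)
    # Blend back toward the positive branch.
    return guided * nag_alpha + z_pos * (1.0 - nag_alpha)

print(nag(torch.randn(2, 16, 8), torch.randn(2, 16, 8)).shape)  # torch.Size([2, 16, 8])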
132
comfy_extras/nodes_painter.py
Normal file
@@ -0,0 +1,132 @@
from __future__ import annotations

import hashlib
import os

import numpy as np
import torch
from PIL import Image

import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override


def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
    hex_color = hex_color.lstrip("#")
    if len(hex_color) != 6:
        return (0.0, 0.0, 0.0)
    r = int(hex_color[0:2], 16) / 255.0
    g = int(hex_color[2:4], 16) / 255.0
    b = int(hex_color[4:6], 16) / 255.0
    return (r, g, b)


class PainterNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Painter",
            display_name="Painter",
            category="image",
            inputs=[
                io.Image.Input(
                    "image",
                    optional=True,
                    tooltip="Optional base image to paint over",
                ),
                io.String.Input(
                    "mask",
                    default="",
                    socketless=True,
                    extra_dict={"widgetType": "PAINTER", "image_upload": True},
                ),
                io.Int.Input(
                    "width",
                    default=512,
                    min=64,
                    max=4096,
                    step=64,
                    socketless=True,
                    extra_dict={"hidden": True},
                ),
                io.Int.Input(
                    "height",
                    default=512,
                    min=64,
                    max=4096,
                    step=64,
                    socketless=True,
                    extra_dict={"hidden": True},
                ),
                io.String.Input(
                    "bg_color",
                    default="#000000",
                    socketless=True,
                    extra_dict={"hidden": True, "widgetType": "COLOR"},
                ),
            ],
            outputs=[
                io.Image.Output("IMAGE"),
                io.Mask.Output("MASK"),
            ],
        )

    @classmethod
    def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
        if image is not None:
            h, w = image.shape[1], image.shape[2]
            base_image = image
        else:
            h, w = height, width
            r, g, b = hex_to_rgb(bg_color)
            base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
            base_image[0, :, :, 0] = r
            base_image[0, :, :, 1] = g
            base_image[0, :, :, 2] = b

        if mask and mask.strip():
            mask_path = folder_paths.get_annotated_filepath(mask)
            painter_img = node_helpers.pillow(Image.open, mask_path)
            painter_img = painter_img.convert("RGBA")

            if painter_img.size != (w, h):
                painter_img = painter_img.resize((w, h), Image.LANCZOS)

            painter_np = np.array(painter_img).astype(np.float32) / 255.0
            painter_rgb = painter_np[:, :, :3]
            painter_alpha = painter_np[:, :, 3:4]

            mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)

            base_np = base_image[0].cpu().numpy()
            composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
            out_image = torch.from_numpy(composited).unsqueeze(0)
        else:
            mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
            out_image = base_image

        return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))

    @classmethod
    def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
        if mask and mask.strip():
            mask_path = folder_paths.get_annotated_filepath(mask)
            if os.path.exists(mask_path):
                m = hashlib.sha256()
                with open(mask_path, "rb") as f:
                    m.update(f.read())
                return m.digest().hex()
        return ""


class PainterExtension(ComfyExtension):
    @override
    async def get_node_list(self):
        return [PainterNode]


async def comfy_entrypoint():
    return PainterExtension()

@@ -655,7 +655,6 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
        batched = batch_masks(values)
        return io.NodeOutput(batched)


class PostProcessingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:

@@ -1,103 +0,0 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI

api = ComfyAPI()


async def register_replacements():
    """Register all built-in node replacements."""
    await register_replacements_longeredge()
    await register_replacements_batchimages()
    await register_replacements_upscaleimage()
    await register_replacements_controlnet()
    await register_replacements_load3d()
    await register_replacements_preview3d()
    await register_replacements_svdimg2vid()
    await register_replacements_conditioningavg()


async def register_replacements_longeredge():
    # No dynamic inputs here
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="ImageScaleToMaxDimension",
        old_node_id="ResizeImagesByLongerEdge",
        old_widget_ids=["longer_edge"],
        input_mapping=[
            {"new_id": "image", "old_id": "images"},
            {"new_id": "largest_size", "old_id": "longer_edge"},
            {"new_id": "upscale_method", "set_value": "lanczos"},
        ],
        # just to test the frontend output_mapping code; it does nothing here
        output_mapping=[{"new_idx": 0, "old_idx": 0}],
    ))


async def register_replacements_batchimages():
    # BatchImages node uses Autogrow
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="BatchImagesNode",
        old_node_id="ImageBatch",
        input_mapping=[
            {"new_id": "images.image0", "old_id": "image1"},
            {"new_id": "images.image1", "old_id": "image2"},
        ],
    ))


async def register_replacements_upscaleimage():
    # ResizeImageMaskNode uses DynamicCombo
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="ResizeImageMaskNode",
        old_node_id="ImageScaleBy",
        old_widget_ids=["upscale_method", "scale_by"],
        input_mapping=[
            {"new_id": "input", "old_id": "image"},
            {"new_id": "resize_type", "set_value": "scale by multiplier"},
            {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
            {"new_id": "scale_method", "old_id": "upscale_method"},
        ],
    ))


async def register_replacements_controlnet():
    # T2IAdapterLoader → ControlNetLoader
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="ControlNetLoader",
        old_node_id="T2IAdapterLoader",
        input_mapping=[
            {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
        ],
    ))


async def register_replacements_load3d():
    # Load3DAnimation merged into Load3D
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="Load3D",
        old_node_id="Load3DAnimation",
    ))


async def register_replacements_preview3d():
    # Preview3DAnimation merged into Preview3D
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="Preview3D",
        old_node_id="Preview3DAnimation",
    ))


async def register_replacements_svdimg2vid():
    # Typo fix: SDV → SVD
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="SVD_img2vid_Conditioning",
        old_node_id="SDV_img2vid_Conditioning",
    ))


async def register_replacements_conditioningavg():
    # Typo fix: trailing space in node name
    await api.node_replacement.register(io.NodeReplace(
        new_node_id="ConditioningAverage",
        old_node_id="ConditioningAverage ",
    ))


class NodeReplacementsExtension(ComfyExtension):
    async def on_load(self) -> None:
        await register_replacements()

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return []


async def comfy_entrypoint() -> NodeReplacementsExtension:
    return NodeReplacementsExtension()
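
Note: a hypothetical before/after of what the ImageScaleBy registration above is meant to achieve on an API-format prompt entry (node numbers, link targets, and values here are made up for illustration):

old_node = {
    "class_type": "ImageScaleBy",
    "inputs": {"image": ["5", 0], "upscale_method": "lanczos", "scale_by": 2.0},
}
new_node = {
    "class_type": "ResizeImageMaskNode",
    "inputs": {
        "input": ["5", 0],                     # old "image"
        "resize_type": "scale by multiplier",  # injected via set_value
        "resize_type.multiplier": 2.0,         # old "scale_by"
        "scale_method": "lanczos",             # old "upscale_method"
    },
}
print(new_node["class_type"])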

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.14.1"
__version__ = "0.13.0"

4
nodes.py
@@ -2264,7 +2264,6 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
    if not isinstance(extension, ComfyExtension):
        logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
        return False
    await extension.on_load()
    node_list = await extension.get_node_list()
    if not isinstance(node_list, list):
        logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2436,8 +2435,7 @@ async def init_builtin_extra_nodes():
        "nodes_lora_debug.py",
        "nodes_color.py",
        "nodes_toolkit.py",
        "nodes_replacements.py",
        "nodes_nag.py",
        "nodes_painter.py",
    ]

    import_failed = []

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.14.1"
version = "0.13.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.43
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch
torchsde

@@ -40,7 +40,6 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -205,7 +204,6 @@ class PromptServer():
        self.model_file_manager = ModelFileManager()
        self.custom_node_manager = CustomNodeManager()
        self.subgraph_manager = SubgraphManager()
        self.node_replace_manager = NodeReplaceManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = execution.PromptQueue(self)
@@ -889,8 +887,6 @@ class PromptServer():
        if "partial_execution_targets" in json_data:
            partial_execution_targets = json_data["partial_execution_targets"]

        self.node_replace_manager.apply_replacements(prompt)

        valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
        extra_data = {}
        if "extra_data" in json_data:
@@ -999,7 +995,6 @@ class PromptServer():
        self.model_file_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
        self.node_replace_manager.add_routes(self.routes)
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
