mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 16:49:57 +00:00
Compare commits
19 Commits
cbyrne/gls
...
rizz--disp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
583514148a | ||
|
|
2687652530 | ||
|
|
6d11cc7354 | ||
|
|
f262444dd4 | ||
|
|
239ddd3327 | ||
|
|
83dd65f23a | ||
|
|
8ad38d2073 | ||
|
|
6c14f129af | ||
|
|
58dcc97dcf | ||
|
|
19236edfa4 | ||
|
|
73c3f86973 | ||
|
|
262abf437b | ||
|
|
5284e6bf69 | ||
|
|
44f8598521 | ||
|
|
fe52843fe5 | ||
|
|
c39653163d | ||
|
|
18927538a1 | ||
|
|
8a6fbc2dc2 | ||
|
|
b44fc4c589 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,7 +11,7 @@ extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
venv*/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
|
||||
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
|
||||
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
# running on sequences img
|
||||
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
|
||||
@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img blocks
|
||||
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
|
||||
@@ -142,6 +142,7 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -231,6 +232,7 @@ class Flux(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
dtype = self.get_dtype_inference()
|
||||
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
def get_dtype_inference(self):
    """Return the dtype used for inference.

    ``self.manual_cast_dtype`` wins whenever it is set; otherwise the value
    reported by ``self.get_dtype()`` is returned.
    """
    model_dtype = self.get_dtype()
    override = self.manual_cast_dtype
    return model_dtype if override is None else override
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
||||
input_shapes += shape
|
||||
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
dtype = self.get_dtype_inference()
|
||||
#TODO: this needs to be tweaked
|
||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||
@@ -1165,7 +1167,7 @@ class Anima(BaseModel):
|
||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||
|
||||
if torch.is_inference_mode_enabled(): # if not we are training
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||
else:
|
||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
@@ -406,13 +406,16 @@ class ModelPatcher:
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def disable_model_cfg1_optimization(self):
    """Set the ``disable_cfg1_optimization`` flag in ``self.model_options``."""
    options = self.model_options
    options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
self.disable_model_cfg1_optimization()
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||
|
||||
17
comfy/ops.py
17
comfy/ops.py
@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
@@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
||||
(want_requant and len(fns) == 0 or update_weight)):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
if want_requant and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
elif update_weight:
|
||||
@@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
# NOTE: offloadable=False is a legacy mode; if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
29
comfy/sd.py
29
comfy/sd.py
@@ -423,6 +423,19 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
    """Prepare the text model and delegate generation to ``cond_stage_model``.

    Clip options are reset, the configured layer (if any) is re-applied, the
    model is loaded, and the execution device is set before the sampling
    arguments are forwarded unchanged to ``cond_stage_model.generate``.
    """
    self.cond_stage_model.reset_clip_options()

    layer = self.layer_idx
    if layer is not None:
        self.cond_stage_model.set_clip_options({"layer": layer})

    self.load_model()
    self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})

    sampling = dict(
        do_sample=do_sample,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        repetition_penalty=repetition_penalty,
        seed=seed,
    )
    return self.cond_stage_model.generate(tokens, **sampling)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
    """Decode ``token_ids`` by delegating to ``self.tokenizer.decode``."""
    tokenizer = self.tokenizer
    decoded = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    return decoded
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
@@ -1182,6 +1195,7 @@ class TEModel(Enum):
|
||||
JINA_CLIP_2 = 19
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1210,7 +1224,10 @@ def detect_te_model(sd):
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B
|
||||
if 'vision_model.embeddings.patch_embedding.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B_VISION
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
@@ -1270,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
if "lm_head.weight" in clip_data[i]:
|
||||
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
|
||||
|
||||
tokenizer_data = {}
|
||||
clip_target = EmptyClass()
|
||||
@@ -1335,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B_VISION:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_12B:
|
||||
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
|
||||
@@ -308,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
    """Embed the prompt tokens and run the transformer's generation loop.

    ``tokens`` is either a batch of token entries or a dict of such batches,
    in which case the first value is used. Only element 0 of every token
    entry is kept before the ids are embedded with ``process_tokens`` on
    ``self.execution_device``.
    """
    batches = next(iter(tokens.values())) if isinstance(tokens, dict) else tokens  # todo: get this better?

    # Keep only element 0 of every token entry.
    tokens_only = []
    for batch in batches:
        tokens_only.append([entry[0] for entry in batch])

    processed = self.process_tokens(tokens_only, device=self.execution_device)
    embeds = processed[0]
    return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
@@ -663,6 +672,9 @@ class SDTokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
    """Decode ``token_ids`` with the wrapped ``self.tokenizer``."""
    backend = self.tokenizer
    text = backend.decode(token_ids, skip_special_tokens=skip_special_tokens)
    return text
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
||||
if name is not None:
|
||||
@@ -686,6 +698,9 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return getattr(self, self.clip).state_dict()
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
    """Decode ``token_ids`` with the tokenizer stored under the attribute named by ``self.clip``."""
    inner_tokenizer = getattr(self, self.clip)
    return inner_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
@@ -722,3 +737,6 @@ class SD1ClipModel(torch.nn.Module):
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
    """Forward a generation request to the model stored under the attribute named by ``self.clip``."""
    inner_model = getattr(self, self.clip)
    sampling = dict(
        do_sample=do_sample,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        repetition_penalty=repetition_penalty,
        seed=seed,
    )
    return inner_model.generate(tokens, **sampling)
|
||||
|
||||
@@ -3,6 +3,8 @@ import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, Tuple
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
import comfy.utils
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
|
||||
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
    """Gemma3 4B configuration with vision-related settings added."""
    # Neither attribute is annotated, so @dataclass leaves them as plain class
    # attributes: they are not generated __init__ parameters.
    vision_config = GEMMA3_VISION_CONFIG
    mm_tokens_per_image = 256
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
vocab_size: int = 262208
|
||||
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
vision_config = GEMMA3_VISION_CONFIG
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@@ -441,8 +450,10 @@ class Attention(nn.Module):
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states)
|
||||
xk = self.k_proj(hidden_states)
|
||||
xv = self.v_proj(hidden_states)
|
||||
@@ -477,6 +488,11 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
@@ -559,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
sliding_window = None
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
sliding_window = self.sliding_attention
|
||||
if x.shape[1] > self.sliding_attention:
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
|
||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask + sliding_mask
|
||||
@@ -581,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@@ -765,6 +784,104 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseGenerate:
    """Mixin that adds autoregressive token generation to a text model.

    The host class must expose ``self.model`` with a ``config``, an
    ``embed_tokens`` module, an optional ``lm_head`` module, and a ``forward``
    accepting ``embeds`` and ``past_key_values``.
    """

    def logits(self, x):
        """Project the last position of ``x`` onto the vocabulary.

        ``self.model.lm_head`` is used when the model has one; otherwise the
        ``embed_tokens`` weight is reused as the output projection.
        """
        last_hidden = x[:, -1:]
        if hasattr(self.model, "lm_head"):
            module = self.model.lm_head
        else:
            module = self.model.embed_tokens

        offload_stream = None
        if module.comfy_cast_weights:
            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, last_hidden, offloadable=True)
        else:
            # Use the weight of the module selected above. The previous code
            # always took embed_tokens.weight here, silently ignoring lm_head.
            weight = module.weight.to(x)

        x = torch.nn.functional.linear(last_hidden, weight, None)

        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
        return x

    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=None, execution_dtype=None, min_tokens=0):
        """Generate up to ``max_length`` token ids from prompt embeddings.

        Args:
            embeds: prompt embeddings; a 2-D tensor gets a leading batch dim.
            do_sample: sample from the distribution instead of taking argmax.
            max_length: maximum number of tokens to generate.
            temperature, top_k, top_p, min_p, repetition_penalty: sampling
                controls forwarded to ``sample_token``.
            seed: seed for the sampling generator (only used when sampling).
            stop_tokens: token ids that end generation once produced.
            initial_tokens: token ids counted as history for the repetition
                penalty before anything has been generated.
            execution_dtype: dtype for embeddings and the KV cache; defaults
                to bfloat16 when the device should use it, else float32.
            min_tokens: NOTE(review): accepted but not used by this loop.

        Returns:
            The list of generated token ids, including the stop token if one
            was produced. Only the first batch row is tracked.
        """
        stop_tokens = () if stop_tokens is None else stop_tokens
        # Running history: the caller's prefix plus everything generated here.
        history = [] if initial_tokens is None else list(initial_tokens)

        device = embeds.device
        model_config = self.model.config

        if execution_dtype is None:
            if comfy.model_management.should_use_bf16(device):
                execution_dtype = torch.bfloat16
            else:
                execution_dtype = torch.float32
        embeds = embeds.to(execution_dtype)

        if embeds.ndim == 2:
            embeds = embeds.unsqueeze(0)

        # kv_cache init: one (key buffer, value buffer, index) triple per layer,
        # sized for the prompt plus every token that may be generated.
        max_cache_len = embeds.shape[1] + max_length
        cache_shape = [embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim]
        past_key_values = []
        for _ in range(model_config.num_hidden_layers):
            past_key_values.append((torch.empty(cache_shape, device=device, dtype=execution_dtype),
                                    torch.empty(cache_shape, device=device, dtype=execution_dtype), 0))

        generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None

        generated_token_ids = []
        pbar = comfy.utils.ProgressBar(max_length)

        # Generation loop
        for _ in tqdm(range(max_length), desc="Generating tokens"):
            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
            logits = self.logits(x)[:, -1]
            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, history, generator, do_sample=do_sample)
            token_id = next_token[0].item()
            generated_token_ids.append(token_id)
            history.append(token_id)
            pbar.update(1)

            if token_id in stop_tokens:
                break

            # Only embed the new token when another step will consume it.
            embeds = self.model.embed_tokens(next_token).to(execution_dtype)

        return generated_token_ids

    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
        """Pick the next token id for every row of ``logits``.

        Greedy argmax is used when ``do_sample`` is false or ``temperature``
        is 0. Otherwise the repetition penalty, temperature, top-k, min-p and
        top-p filters are applied in that order before a multinomial draw.
        The caller's ``logits`` tensor is left untouched.

        Returns:
            A tensor holding one token id per row, with a trailing dim of 1.
        """
        if not do_sample or temperature == 0.0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        # Sampling mode: work on a private copy so filtering never leaks back
        # into the tensor the caller passed in.
        logits = logits.clone()
        floor = torch.finfo(logits.dtype).min

        if repetition_penalty != 1.0 and token_history:
            # Vectorised over all previously seen ids, avoiding a host/device
            # sync per token.
            seen = torch.tensor(sorted(set(token_history)), dtype=torch.long, device=logits.device)
            picked = logits[:, seen]
            logits[:, seen] = torch.where(picked < 0, picked * repetition_penalty, picked * (1 / repetition_penalty))

        if temperature != 1.0:
            logits = logits / temperature

        if top_k > 0:
            # torch.topk raises if k exceeds the vocabulary size, so clamp it.
            k = min(top_k, logits.shape[-1])
            kth_value = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_value, floor)

        if min_p > 0.0:
            probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
            top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
            min_threshold = min_p * top_probs
            logits = logits.masked_fill(probs_before_filter < min_threshold, floor)

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Always keep the most likely token so the distribution is never empty.
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
            logits = logits.masked_fill(indices_to_remove, floor)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
@@ -871,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_7BVLI_Config(**config_dict)
|
||||
@@ -881,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
# todo: should this be tied or not?
|
||||
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||
@@ -923,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_4B_Config(**config_dict)
|
||||
@@ -932,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
    """Gemma3 4B text model paired with a vision tower and a multi-modal projector."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Gemma3_4B_Vision_Config(**config_dict)
        vision_config = config.vision_config

        self.num_layers = config.num_hidden_layers
        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
        self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
        self.vision_model = comfy.clip_model.CLIPVision(vision_config, dtype, device, operations)
        self.image_size = vision_config["image_size"]

    def preprocess_embed(self, embed, device):
        """Turn an image embed entry into projected vision features.

        Returns ``(features, None)`` for entries whose ``"type"`` is
        ``"image"`` and ``(None, None)`` for anything else.
        """
        if embed["type"] != "image":
            return None, None

        image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
        pixels = image.to(device, dtype=torch.float32)
        vision_out = self.vision_model(pixels)
        return self.multi_modal_projector(vision_out[0]), None
|
||||
|
||||
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_12B_Config(**config_dict)
|
||||
|
||||
@@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import torch
|
||||
import comfy.utils
|
||||
import math
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||
|
||||
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
class Gemma3_Tokenizer():
    """Tokenizer mixin that applies the Gemma3 chat template and image placeholders.

    It is meant to be combined with a tokenizer base class that provides
    ``self.tokenizer`` and ``tokenize_with_weights``.
    """

    def state_dict(self):
        """Return the serialized sentencepiece model under the ``spiece_model`` key."""
        serialized = self.tokenizer.serialize_model()
        return {"spiece_model": serialized}

    def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
        """Tokenize ``text``, optionally wrapping it in a chat template.

        The template is skipped when ``skip_template`` is true or ``text``
        already starts with ``<start_of_turn>``. When ``image`` is given it
        is rescaled to roughly 896*896 pixels and substituted for the
        ``<image_soft_token>`` (id 262144) entries of the tokenized output.
        Extra keyword arguments are accepted and ignored.
        """
        self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
        self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"

        images = []
        if image is not None:
            samples = image.movedim(-1, 1)
            src_height, src_width = samples.shape[2], samples.shape[3]
            total = int(896 * 896)

            # Scale so the pixel count lands near `total`, keeping the aspect ratio.
            scale_by = math.sqrt(total / (src_width * src_height))
            width = round(src_width * scale_by)
            height = round(src_height * scale_by)

            resized = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
            images.append(resized[:, :, :, :3])

        already_templated = text.startswith('<start_of_turn>')
        if already_templated or skip_template:
            llama_text = text
        elif llama_template is not None:
            llama_text = llama_template.format(text)
        elif images:
            llama_text = self.llama_template_images.format(text)
        else:
            llama_text = self.llama_template.format(text)

        text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)

        if images:
            used = 0
            for row in text_tokens:
                for pos, token in enumerate(row):
                    if token[0] == 262144 and used < len(images):
                        row[pos] = ({"type": "image", "data": images[used]},) + token[1:]
                        used += 1
        return text_tokens
|
||||
|
||||
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
    """Sentencepiece-backed tokenizer configured for the Gemma 3 12B encoder."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Configure the base tokenizer from `tokenizer_data["spiece_model"]`.

        `embedding_directory` is accepted but not forwarded to the base class.
        """
        spiece_model = tokenizer_data.get("spiece_model", None)
        # Literal strings handed to the tokenizer class with fixed token ids.
        special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
        spiece_args = {"add_bos": True, "add_eos": False, "special_tokens": special_tokens}
        super().__init__(
            spiece_model,
            pad_with_end=False,
            embedding_size=3840,
            embedding_key='gemma3_12b',
            tokenizer_class=SPieceTokenizer,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=512,
            pad_left=True,
            disable_weights=True,
            tokenizer_args=spiece_args,
            tokenizer_data=tokenizer_data,
        )
|
||||
|
||||
|
||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper exposing Gemma3_12BTokenizer under the name "gemma3_12b"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="gemma3_12b",
            tokenizer=Gemma3_12BTokenizer,
        )
|
||||
|
||||
|
||||
class Gemma3_12BModel(sd1_clip.SDClipModel):
    """SDClipModel wrapper around comfy.text_encoders.llama.Gemma3_12B."""

    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        """Set up the wrapped model.

        When `model_options` has "llama_quantization_metadata", a copy of the
        options is made with that value re-exposed as "quantization_metadata";
        the caller's dict is not mutated.
        """
        quant_meta = model_options.get("llama_quantization_metadata", None)
        if quant_meta is not None:
            model_options = model_options.copy()
            model_options["quantization_metadata"] = quant_meta

        self.dtypes = {dtype}
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma3_12B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
        """Tokenize templated text and splice precomputed image embeddings in.

        Each token whose id is 262144 is replaced, in order, by one row of
        `image_embeds` until the rows run out.

        NOTE(review): relies on the parent class providing
        `tokenize_with_weights` and returning a mapping of token rows — confirm.
        """
        text_tokens = super().tokenize_with_weights(llama_template.format(text), return_word_ids)
        used = 0
        for key in text_tokens:
            for row in text_tokens[key]:
                for pos, token in enumerate(row):
                    if token[0] == 262144:
                        if image_embeds is not None and used < image_embeds.shape[0]:
                            row[pos] = ({"type": "embedding", "data": image_embeds[used], "original_type": "image"},) + token[1:]
                            used += 1
        return text_tokens

    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
        """Embed `tokens`, rescale the image spans and run the transformer's generate().

        Only the first element of each token tuple is used. Generation stops
        at token id 106.
        """
        token_ids = [[entry[0] for entry in batch] for batch in tokens]
        embeds, _, _, embeds_info = self.process_tokens(token_ids, self.execution_device)
        # Image spans are divided in place by sqrt(hidden_size).
        scale = self.transformer.model.config.hidden_size ** 0.5
        comfy.utils.normalize_image_embeddings(embeds, embeds_info, scale)
        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])  # 106 is <end_of_turn>
|
||||
|
||||
class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
@@ -112,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return out.to(out_device), pooled
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
    """Delegate generation to `self.gemma3_12b`, passing the "gemma3_12b" entry of `tokens`.

    All sampling arguments are forwarded positionally and unchanged.
    """
    delegate = self.gemma3_12b.generate
    return delegate(
        tokens["gemma3_12b"],
        do_sample,
        max_length,
        temperature,
        top_k,
        top_p,
        min_p,
        repetition_penalty,
        seed,
    )
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
@@ -152,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return LTXAVTEModel_
|
||||
|
||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
    """Build a Gemma3_12BModel subclass with dtype and quantization overrides baked in.

    Args:
        dtype_llama: when not None, replaces the `dtype` given to the
            subclass constructor.
        llama_quantization_metadata: when not None, stored in a copy of
            `model_options` under "llama_quantization_metadata".

    Returns:
        The generated class, not an instance.
    """
    class Gemma3_12BModel_(Gemma3_12BModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                # Copy first so the caller's (or the shared default) dict is untouched.
                model_options = model_options.copy()
                model_options["llama_quantization_metadata"] = llama_quantization_metadata
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return Gemma3_12BModel_
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
from comfy.text_encoders.lt import Gemma3_Tokenizer
|
||||
import comfy.utils
|
||||
|
||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
    """Sentencepiece-backed tokenizer configured for the Gemma 2 2B encoder."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Configure the base tokenizer from `tokenizer_data["spiece_model"]`.

        The base initialiser is invoked exactly once, with the special-token
        table included. `embedding_directory` is accepted but not forwarded.
        """
        tokenizer = tokenizer_data.get("spiece_model", None)
        # Literal string handed to the tokenizer class with a fixed token id.
        special_tokens = {"<end_of_turn>": 107}
        super().__init__(
            tokenizer,
            pad_with_end=False,
            embedding_size=2304,
            embedding_key='gemma2_2b',
            tokenizer_class=SPieceTokenizer,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens},
            tokenizer_data=tokenizer_data,
        )

    def state_dict(self):
        """Return the serialized tokenizer model under the "spiece_model" key."""
        return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
    """Sentencepiece-backed tokenizer configured for the Gemma 3 4B encoder.

    `state_dict` and the template-aware `tokenize_with_weights` come from
    the `Gemma3_Tokenizer` mixin imported at the top of this module.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Configure the base tokenizer from `tokenizer_data["spiece_model"]`.

        `embedding_directory` is accepted but not forwarded to the base class.
        """
        tokenizer = tokenizer_data.get("spiece_model", None)
        # Literal strings handed to the tokenizer class with fixed token ids.
        special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
        super().__init__(
            tokenizer,
            pad_with_end=False,
            embedding_size=2560,
            embedding_key='gemma3_4b',
            tokenizer_class=SPieceTokenizer,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens},
            disable_weights=True,
            tokenizer_data=tokenizer_data,
        )
|
||||
|
||||
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
    """Initialise the base class with comfy.text_encoders.llama.Gemma2_2B as the model class.

    `attention_mask` controls both whether attention masks are used and
    whether they are returned.
    """
    super().__init__(
        device=device,
        layer=layer,
        layer_idx=layer_idx,
        textmodel_json_config={},
        dtype=dtype,
        special_tokens={"start": 2, "pad": 0},
        layer_norm_hidden_state=False,
        model_class=comfy.text_encoders.llama.Gemma2_2B,
        enable_attention_masks=attention_mask,
        return_attention_masks=attention_mask,
        model_options=model_options,
    )
|
||||
|
||||
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
    """Call the parent generate() with token id 107 as the only stop token.

    All other arguments are forwarded positionally and unchanged.
    """
    return super().generate(
        embeds,
        do_sample,
        max_length,
        temperature,
        top_k,
        top_p,
        min_p,
        repetition_penalty,
        seed,
        stop_tokens=[107],
    )
|
||||
|
||||
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
@@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
    """Call the parent generate() with token id 106 as the only stop token.

    All other arguments are forwarded positionally and unchanged.
    """
    return super().generate(
        embeds,
        do_sample,
        max_length,
        temperature,
        top_k,
        top_p,
        min_p,
        repetition_penalty,
        seed,
        stop_tokens=[106],
    )
|
||||
|
||||
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
    """SDClipModel wrapper around comfy.text_encoders.llama.Gemma3_4B_Vision."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        """Set up the wrapped model.

        When `model_options` has "llama_quantization_metadata", a copy of the
        options is made with that value re-exposed as "quantization_metadata";
        the caller's dict is not mutated.
        """
        quant_meta = model_options.get("llama_quantization_metadata", None)
        if quant_meta is not None:
            model_options = model_options.copy()
            model_options["quantization_metadata"] = quant_meta

        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma3_4B_Vision,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )

    def process_tokens(self, tokens, device):
        """Embed `tokens` and divide the image spans in place by sqrt(hidden_size).

        NOTE(review): the parent call is unpacked as a 4-tuple but this
        override returns only the embeddings — confirm every caller of
        `process_tokens` on this class expects the narrowed return value.
        """
        embeds, _, _, embeds_info = super().process_tokens(tokens, device)
        scale = self.transformer.model.config.hidden_size ** 0.5
        comfy.utils.normalize_image_embeddings(embeds, embeds_info, scale)
        return embeds
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper that defaults to the Gemma 2 2B text model."""

    def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
        super().__init__(
            device=device,
            dtype=dtype,
            name=name,
            clip_model=clip_model,
            model_options=model_options,
        )
|
||||
@@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
model = Gemma3_4BModel
|
||||
elif model_type == "gemma3_4b_vision":
|
||||
model = Gemma3_4B_Vision_Model
|
||||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
|
||||
@@ -6,9 +6,10 @@ class SPieceTokenizer:
|
||||
def from_pretrained(path, **kwargs):
|
||||
return SPieceTokenizer(path, **kwargs)
|
||||
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
self.special_tokens = special_tokens
|
||||
import sentencepiece
|
||||
if torch.is_tensor(tokenizer_path):
|
||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||
@@ -27,8 +28,32 @@ class SPieceTokenizer:
|
||||
return out
|
||||
|
||||
def __call__(self, string):
    """Tokenize `string` and return {"input_ids": [...]}.

    When `self.special_tokens` maps literal strings to ids and at least one
    of them occurs in `string`, the text is split around those literals.
    Each literal becomes its mapped id and the text in between is encoded
    by sentencepiece without per-fragment BOS/EOS. BOS/EOS are then added
    once around the whole sequence according to `self.add_bos` and
    `self.add_eos`, so this path honours the same flags the instance was
    built with. Otherwise the string goes to `self.tokenizer.encode`
    unchanged.

    NOTE(review): the plain path presumably gets its BOS/EOS from how the
    processor is configured in __init__ (not visible here) — confirm the two
    paths agree.
    """
    if self.special_tokens:
        import re

        # Longest literal first, so a special token that is a prefix of
        # another one cannot shadow it in the alternation.
        literals = sorted(self.special_tokens, key=len, reverse=True)
        pattern = '|'.join(re.escape(token) for token in literals)
        # The capturing group keeps the matched literals in the result.
        parts = re.split(f'({pattern})', string)
        if len(parts) > 1:  # at least one special token was found
            result = []
            if self.add_bos:
                bos = self.tokenizer.bos_id()
                if bos >= 0:  # a negative id means the model defines no BOS
                    result.append(bos)
            for part in parts:
                if not part:
                    continue
                if part in self.special_tokens:
                    result.append(self.special_tokens[part])
                else:
                    result.extend(self.tokenizer.encode(part, add_bos=False, add_eos=False))
            if self.add_eos:
                eos = self.tokenizer.eos_id()
                if eos >= 0:
                    result.append(eos)
            return {"input_ids": result}

    out = self.tokenizer.encode(string)
    return {"input_ids": out}
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False):
    """Decode `token_ids` to text via the underlying tokenizer.

    When `skip_special_tokens` is true and special tokens are configured,
    ids that appear as values in `self.special_tokens` are dropped first.
    """
    if skip_special_tokens and self.special_tokens:
        excluded = set(self.special_tokens.values())
        token_ids = list(filter(lambda tid: tid not in excluded, token_ids))

    return self.tokenizer.decode(token_ids)
|
||||
|
||||
def serialize_model(self):
    """Return the tokenizer's serialized model proto as a torch.ByteTensor."""
    proto_bytes = self.tokenizer.serialized_model_proto()
    byte_values = list(proto_bytes)
    return torch.ByteTensor(byte_values)
|
||||
|
||||
@@ -1418,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
    """Divide the image spans of `embeds` by `scale_factor`, in place.

    Args:
        embeds: tensor indexed as [:, position, :]; modified in place.
        embeds_info: iterable of mappings. Entries whose "type" is "image"
            name a span through their "index" (start) and "size" (length).
        scale_factor: divisor applied to every image span.

    Returns:
        None.
    """
    image_spans = (
        (entry["index"], entry["index"] + entry["size"])
        for entry in embeds_info
        if entry.get("type") == "image"
    )
    for begin, end in image_spans:
        embeds[:, begin:end, :] /= scale_factor
|
||||
|
||||
@@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
|
||||
slider = "slider"
|
||||
|
||||
|
||||
class ControlAfterGenerate(str, Enum):
    """Named modes accepted by the `control_after_generate` input option.

    Members are also `str` instances equal to their own names.
    """
    fixed = "fixed"
    increment = "increment"
    decrement = "decrement"
    randomize = "randomize"
|
||||
|
||||
class _ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
@@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
|
||||
class Input(WidgetInput):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
@@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
|
||||
tooltip: str=None,
|
||||
lazy: bool=None,
|
||||
default: str | int | Enum = None,
|
||||
control_after_generate: bool=None,
|
||||
control_after_generate: bool | ControlAfterGenerate=None,
|
||||
upload: UploadType=None,
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
@@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
|
||||
Type = list[str]
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
||||
self.multiselect = True
|
||||
@@ -1203,6 +1209,30 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
    """IO type for a box described by x, y, width and height integers."""

    class BoundingBoxDict(TypedDict):
        x: int
        y: int
        width: int
        height: int

    Type = BoundingBoxDict

    class Input(WidgetInput):
        """Widget input for a bounding box, optionally tied to a named component."""

        def __init__(
            self,
            id: str,
            display_name: str=None,
            optional=False,
            tooltip: str=None,
            socketless: bool=True,
            default: dict=None,
            component: str=None,
        ):
            # The fifth positional argument of the parent initialiser is passed as None.
            super().__init__(id, display_name, optional, tooltip, None, default, socketless)
            self.component = component
            if default is None:
                # Fallback box used when the caller gives no default.
                self.default = {"x": 0, "y": 0, "width": 512, "height": 512}

        def as_dict(self):
            """Return the parent dict, adding "component" when one is set."""
            payload = super().as_dict()
            if self.component:
                payload["component"] = self.component
            return payload
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2097,6 +2127,7 @@ __all__ = [
|
||||
"UploadType",
|
||||
"RemoteOptions",
|
||||
"NumberDisplay",
|
||||
"ControlAfterGenerate",
|
||||
|
||||
"comfytype",
|
||||
"Custom",
|
||||
@@ -2183,5 +2214,6 @@ __all__ = [
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
    """Output-format options nested inside `GeminiImageConfig`."""

    # MIME type of the returned image; defaults to PNG.
    mimeType: str = Field(default="image/png")
    # Optional compression quality; left unset by default.
    compressionQuality: int | None = Field(default=None)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
    """Image-related settings of a Gemini generation request."""

    aspectRatio: str | None = Field(default=None)
    imageSize: str | None = Field(default=None)
    # A fresh options object is created per instance (mimeType "image/png").
    imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
|
||||
@@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
|
||||
}
|
||||
|
||||
|
||||
class RecraftModel(str, Enum):
|
||||
recraftv3 = 'recraftv3'
|
||||
recraftv2 = 'recraftv2'
|
||||
|
||||
|
||||
class RecraftImageSize(str, Enum):
|
||||
res_1024x1024 = '1024x1024'
|
||||
res_1365x1024 = '1365x1024'
|
||||
@@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
|
||||
res_1707x1024 = '1707x1024'
|
||||
|
||||
|
||||
# Size strings offered for the "recraftv4" model option.
# presumably "<width>x<height>" in pixels — confirm against the Recraft API docs
RECRAFT_V4_SIZES = [
    "1024x1024",
    "1536x768",
    "768x1536",
    "1280x832",
    "832x1280",
    "1216x896",
    "896x1216",
    "1152x896",
    "896x1152",
    "832x1344",
    "1280x896",
    "896x1280",
    "1344x768",
    "768x1344",
]

# Size strings offered for the "recraftv4_pro" model option.
# NOTE(review): "1434x1024" / "1024x1434" are far smaller than the other
# entries in this list — confirm they are intended.
RECRAFT_V4_PRO_SIZES = [
    "2048x2048",
    "3072x1536",
    "1536x3072",
    "2560x1664",
    "1664x2560",
    "2432x1792",
    "1792x2432",
    "2304x1792",
    "1792x2304",
    "1664x2688",
    "1434x1024",
    "1024x1434",
    "2560x1792",
    "1792x2560",
]
|
||||
|
||||
|
||||
class RecraftColorObject(BaseModel):
|
||||
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||
|
||||
@@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
    """Request body for Recraft image generation.

    `size` and `model` are plain strings, each declared exactly once, so
    values outside the older enum sets (such as the V4 sizes and models)
    are accepted. `model` is required.
    """
    prompt: str = Field(..., description='The text prompt describing the image to generate')
    size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
    n: int = Field(..., description='The number of images to generate')
    negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
    model: str = Field(...)
    style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
    substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
    controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
    style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
    strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
    random_seed: int | None = Field(None, description="Seed for video generation")
    # text_layout
|
||||
|
||||
|
||||
class RecraftReturnedObject(BaseModel):
|
||||
|
||||
@@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
import base64
|
||||
import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
|
||||
@@ -119,6 +120,13 @@ async def create_image_parts(
|
||||
return image_parts
|
||||
|
||||
|
||||
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
    """Tell whether `mime` matches `pattern`.

    `pattern` is compared against `mime.value` with `fnmatch`, so shell-style
    globs such as 'image/*' are accepted. A missing MIME type never matches.
    """
    return mime is not None and fnmatch(mime.value, pattern)
|
||||
|
||||
|
||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
@@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
for part in candidate.content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
|
||||
if not parts and blocked_reasons:
|
||||
@@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@@ -626,7 +634,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
|
||||
if not aspect_ratio:
|
||||
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
|
||||
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
image_config = GeminiImageConfig() if aspect_ratio == "auto" else GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
|
||||
if images is not None:
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
@@ -646,7 +654,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
|
||||
@@ -52,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
display_name="Text to 3D model",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@@ -166,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
display_name="Image to 3D Model",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@@ -2260,7 +2260,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncAudioToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Audio",
|
||||
display_name="Lipsync",
|
||||
category="api node/video/Kling",
|
||||
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
|
||||
@@ -573,7 +573,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatNode",
|
||||
display_name="OpenAI ChatGPT",
|
||||
display_name="Text generation (LLM)",
|
||||
category="api node/text/OpenAI",
|
||||
description="Generate text responses from an OpenAI model.",
|
||||
inputs=[
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
@@ -9,6 +8,8 @@ from typing_extensions import override
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.recraft import (
|
||||
RECRAFT_V4_PRO_SIZES,
|
||||
RECRAFT_V4_SIZES,
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
@@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
|
||||
RecraftImageGenerationResponse,
|
||||
RecraftImageSize,
|
||||
RecraftIO,
|
||||
RecraftModel,
|
||||
RecraftStyle,
|
||||
RecraftStyleV3,
|
||||
get_v3_substyles,
|
||||
@@ -39,7 +39,7 @@ async def handle_recraft_file_request(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
path: str,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
mask: torch.Tensor | None = None,
|
||||
total_pixels: int = 4096 * 4096,
|
||||
timeout: int = 1024,
|
||||
request=None,
|
||||
@@ -73,11 +73,11 @@ async def handle_recraft_file_request(
|
||||
def recraft_multipart_parser(
|
||||
data,
|
||||
parent_key=None,
|
||||
formatter: Optional[type[callable]] = None,
|
||||
converted_to_check: Optional[list[list]] = None,
|
||||
formatter: type[callable] | None = None,
|
||||
converted_to_check: list[list] | None = None,
|
||||
is_list: bool = False,
|
||||
return_mode: str = "formdata", # "dict" | "formdata"
|
||||
) -> Union[dict, aiohttp.FormData]:
|
||||
) -> dict | aiohttp.FormData:
|
||||
"""
|
||||
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
|
||||
|
||||
@@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||
display_name="Recraft Style - Infinite Style Library",
|
||||
category="api node/image/Recraft",
|
||||
description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
inputs=[
|
||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||
],
|
||||
@@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
size=size,
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
@@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
strength=round(strength, 2),
|
||||
style=recraft_style.style,
|
||||
@@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
substyle=recraft_style.substyle,
|
||||
@@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
size=size,
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
@@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
substyle=recraft_style.substyle,
|
||||
@@ -961,7 +961,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftRemoveBackgroundNode",
|
||||
display_name="Recraft Remove Background",
|
||||
display_name="Remove Background",
|
||||
category="api node/image/Recraft",
|
||||
description="Remove background from image, and return processed image and mask.",
|
||||
inputs=[
|
||||
@@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
)
|
||||
|
||||
|
||||
class RecraftV4TextToImageNode(IO.ComfyNode):
    """API node that generates raster images with the Recraft V4 or V4 Pro model."""

    @classmethod
    def define_schema(cls):
        """Build the node schema: prompt/model/count inputs, one image output and a price badge."""

        def size_choice(model_name, allowed_sizes, default_size):
            # Each model entry carries its own "size" combo because the allowed
            # sizes and the default differ between the two models.
            return IO.DynamicCombo.Option(
                model_name,
                [
                    IO.Combo.Input(
                        "size",
                        options=allowed_sizes,
                        default=default_size,
                        tooltip="The size of the generated image.",
                    ),
                ],
            )

        node_inputs = [
            IO.String.Input(
                "prompt",
                multiline=True,
                tooltip="Prompt for the image generation. Maximum 10,000 characters.",
            ),
            IO.String.Input(
                "negative_prompt",
                multiline=True,
                tooltip="An optional text description of undesired elements on an image.",
            ),
            IO.DynamicCombo.Input(
                "model",
                options=[
                    size_choice("recraftv4", RECRAFT_V4_SIZES, "1024x1024"),
                    size_choice("recraftv4_pro", RECRAFT_V4_PRO_SIZES, "2048x2048"),
                ],
                tooltip="The model to use for generation.",
            ),
            IO.Int.Input(
                "n",
                default=1,
                min=1,
                max=6,
                tooltip="The number of images to generate.",
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=0xFFFFFFFFFFFFFFFF,
                control_after_generate=True,
                tooltip="Seed to determine if node should re-run; "
                "actual results are nondeterministic regardless of seed.",
            ),
            IO.Custom(RecraftIO.CONTROLS).Input(
                "recraft_controls",
                tooltip="Optional additional controls over the generation via the Recraft Controls node.",
                optional=True,
            ),
        ]

        # Per-image price is looked up by model name and multiplied by the image count.
        badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
            expr="""
            (
              $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
              {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
            )
            """,
        )

        return IO.Schema(
            node_id="RecraftV4TextToImageNode",
            display_name="Recraft V4 Text to Image",
            category="api node/image/Recraft",
            description="Generates images using Recraft V4 or V4 Pro models.",
            inputs=node_inputs,
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        negative_prompt: str,
        model: dict,
        n: int,
        seed: int,
        recraft_controls: RecraftControls | None = None,
    ) -> IO.NodeOutput:
        """Request ``n`` images and return them concatenated along dim 0.

        ``model`` is the dynamic-combo value; its "model" and "size" entries are
        forwarded to the API. ``seed`` is not part of the request. An empty
        ``negative_prompt`` is sent as ``None``.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)

        model_name = model["model"]
        image_size = model["size"]
        controls = recraft_controls.create_api_model() if recraft_controls else None
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            model=model_name,
            size=image_size,
            n=n,
            controls=controls,
        )
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
            response_model=RecraftImageGenerationResponse,
            data=request,
            max_retries=1,
        )

        batches = []
        for item in response.data:
            with handle_recraft_image_output():
                tensor = bytesio_to_image_tensor(await download_url_as_bytesio(item.url, timeout=1024))
            # Give every downloaded image a leading dim so they can be concatenated.
            if len(tensor.shape) < 4:
                tensor = tensor.unsqueeze(0)
            batches.append(tensor)
        return IO.NodeOutput(torch.cat(batches, dim=0))
|
||||
|
||||
|
||||
class RecraftV4TextToVectorNode(IO.ComfyNode):
    """API node that generates SVG output with the Recraft V4 or V4 Pro model."""

    @classmethod
    def define_schema(cls):
        """Build the node schema: prompt/model/count inputs, one SVG output and a price badge."""

        def size_choice(model_name, allowed_sizes, default_size):
            # Each model entry carries its own "size" combo because the allowed
            # sizes and the default differ between the two models.
            return IO.DynamicCombo.Option(
                model_name,
                [
                    IO.Combo.Input(
                        "size",
                        options=allowed_sizes,
                        default=default_size,
                        tooltip="The size of the generated image.",
                    ),
                ],
            )

        node_inputs = [
            IO.String.Input(
                "prompt",
                multiline=True,
                tooltip="Prompt for the image generation. Maximum 10,000 characters.",
            ),
            IO.String.Input(
                "negative_prompt",
                multiline=True,
                tooltip="An optional text description of undesired elements on an image.",
            ),
            IO.DynamicCombo.Input(
                "model",
                options=[
                    size_choice("recraftv4", RECRAFT_V4_SIZES, "1024x1024"),
                    size_choice("recraftv4_pro", RECRAFT_V4_PRO_SIZES, "2048x2048"),
                ],
                tooltip="The model to use for generation.",
            ),
            IO.Int.Input(
                "n",
                default=1,
                min=1,
                max=6,
                tooltip="The number of images to generate.",
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=0xFFFFFFFFFFFFFFFF,
                control_after_generate=True,
                tooltip="Seed to determine if node should re-run; "
                "actual results are nondeterministic regardless of seed.",
            ),
            IO.Custom(RecraftIO.CONTROLS).Input(
                "recraft_controls",
                tooltip="Optional additional controls over the generation via the Recraft Controls node.",
                optional=True,
            ),
        ]

        # Per-image price is looked up by model name and multiplied by the image count.
        badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
            expr="""
            (
              $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
              {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
            )
            """,
        )

        return IO.Schema(
            node_id="RecraftV4TextToVectorNode",
            display_name="Recraft V4 Text to Vector",
            category="api node/image/Recraft",
            description="Generates SVG using Recraft V4 or V4 Pro models.",
            inputs=node_inputs,
            outputs=[
                IO.SVG.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        negative_prompt: str,
        model: dict,
        n: int,
        seed: int,
        recraft_controls: RecraftControls | None = None,
    ) -> IO.NodeOutput:
        """Request ``n`` vector illustrations and return them wrapped in one ``SVG``.

        The request always uses style "vector_illustration" with no substyle.
        ``seed`` is not part of the request. An empty ``negative_prompt`` is
        sent as ``None``.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)

        model_name = model["model"]
        image_size = model["size"]
        controls = recraft_controls.create_api_model() if recraft_controls else None
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            model=model_name,
            size=image_size,
            n=n,
            style="vector_illustration",
            substyle=None,
            controls=controls,
        )
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
            response_model=RecraftImageGenerationResponse,
            data=request,
            max_retries=1,
        )

        # Downloads are awaited one at a time, in response order.
        svg_payloads = [await download_url_as_bytesio(item.url, timeout=1024) for item in response.data]
        return IO.NodeOutput(SVG(svg_payloads))
|
||||
|
||||
|
||||
class RecraftExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension):
|
||||
RecraftCreateStyleNode,
|
||||
RecraftColorRGBNode,
|
||||
RecraftControlsNode,
|
||||
RecraftV4TextToImageNode,
|
||||
RecraftV4TextToVectorNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -622,7 +622,7 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
display_name="Music generation",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
|
||||
@@ -54,6 +54,7 @@ async def execute_task(
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
)
|
||||
if not response.creations:
|
||||
@@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-turbo",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
@@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
|
||||
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$contains(widgets.model, "turbo")
|
||||
? (
|
||||
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
: (
|
||||
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-turbo",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
@@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
|
||||
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$contains(widgets.model, "turbo")
|
||||
? (
|
||||
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
: (
|
||||
$rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu3StartEndToVideoNode(IO.ComfyNode):
    """API node that generates a Vidu Q3 video from a start frame, an end frame and a prompt."""

    @classmethod
    def define_schema(cls):
        """Build the node schema: model options, two frames, prompt, seed, video output, price badge."""

        def model_choice(model_name):
            # Both models expose the same resolution / duration / audio widgets;
            # fresh widget objects are created for every option.
            return IO.DynamicCombo.Option(
                model_name,
                [
                    IO.Combo.Input(
                        "resolution",
                        options=["720p", "1080p"],
                        tooltip="Resolution of the output video.",
                    ),
                    IO.Int.Input(
                        "duration",
                        default=5,
                        min=1,
                        max=16,
                        step=1,
                        display_mode=IO.NumberDisplay.slider,
                        tooltip="Duration of the output video in seconds.",
                    ),
                    IO.Boolean.Input(
                        "audio",
                        default=False,
                        tooltip="When enabled, outputs video with sound "
                        "(including dialogue and sound effects).",
                    ),
                ],
            )

        node_inputs = [
            IO.DynamicCombo.Input(
                "model",
                options=[
                    model_choice("viduq3-pro"),
                    model_choice("viduq3-turbo"),
                ],
                tooltip="Model to use for video generation.",
            ),
            IO.Image.Input("first_frame"),
            IO.Image.Input("end_frame"),
            IO.String.Input(
                "prompt",
                multiline=True,
                tooltip="Prompt description (max 2000 characters).",
            ),
            IO.Int.Input(
                "seed",
                default=1,
                min=0,
                max=2147483647,
                step=1,
                display_mode=IO.NumberDisplay.number,
                control_after_generate=True,
            ),
        ]

        # Price is a per-second rate (by resolution) times the duration; the
        # rate table is chosen by whether the model name contains "turbo".
        badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
            expr="""
            (
              $res := $lookup(widgets, "model.resolution");
              $d := $lookup(widgets, "model.duration");
              $contains(widgets.model, "turbo")
                ? (
                  $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
                  {"type":"usd","usd": $rate * $d}
                )
                : (
                  $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
                  {"type":"usd","usd": $rate * $d}
                )
            )
            """,
        )

        return IO.Schema(
            node_id="Vidu3StartEndToVideoNode",
            display_name="Vidu Q3 Start/End Frame-to-Video Generation",
            category="api node/video/Vidu",
            description="Generate a video from a start frame, an end frame, and a prompt.",
            inputs=node_inputs,
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        model: dict,
        first_frame: Input.Image,
        end_frame: Input.Image,
        prompt: str,
        seed: int,
    ) -> IO.NodeOutput:
        """Upload both frames, run the start/end video task and return the first result as video.

        ``model`` is the dynamic-combo value; its "model", "duration",
        "resolution" and "audio" entries are forwarded in the task request.
        """
        validate_string(prompt, max_length=2000)
        validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)

        # Read the model settings before any upload starts.
        model_name = model["model"]
        duration = model["duration"]
        resolution = model["resolution"]
        with_audio = model["audio"]

        # Frames are uploaded one after another: start frame first, end frame second.
        frame_urls = []
        for frame in (first_frame, end_frame):
            uploaded = await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png")
            frame_urls.append(uploaded[0])

        payload = TaskCreationRequest(
            model=model_name,
            prompt=prompt,
            duration=duration,
            seed=seed,
            resolution=resolution,
            audio=with_audio,
            images=frame_urls,
        )
        results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
        return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension):
|
||||
ViduMultiFrameVideoNode,
|
||||
Vidu3TextToVideoNode,
|
||||
Vidu3ImageToVideoNode,
|
||||
Vidu3StartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ class SaveAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
display_name="Save Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
@@ -698,6 +698,67 @@ class EmptyAudio(IO.ComfyNode):
|
||||
create_empty_audio = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioEqualizer3Band(IO.ComfyNode):
    """Experimental audio node applying a low band, a mid peaking band and a high band in series."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one audio input, gain/frequency controls per band, one audio output."""

        def gain_input(name, tooltip):
            # All three gain controls share the same range and step.
            return IO.Float.Input(name, default=0.0, min=-24.0, max=24.0, step=0.1, tooltip=tooltip)

        controls = [
            IO.Audio.Input("audio"),
            gain_input("low_gain_dB", "Gain for Low frequencies (Bass)"),
            IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
            gain_input("mid_gain_dB", "Gain for Mid frequencies"),
            IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
            IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
            gain_input("high_gain_dB", "Gain for High frequencies (Treble)"),
            IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
        ]
        return IO.Schema(
            node_id="AudioEqualizer3Band",
            search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
            display_name="Audio Equalizer (3-Band)",
            category="audio",
            is_experimental=True,
            inputs=controls,
            outputs=[IO.Audio.Output()],
        )

    @classmethod
    def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
        """Return a new audio dict with the three bands applied in order low, mid, high.

        A band whose gain is exactly 0 is skipped. The input waveform is cloned
        first, so the returned waveform is a separate tensor even when every
        band is skipped. The low and high bands use a fixed Q of 0.707.
        """
        source_waveform, sample_rate = audio["waveform"], audio["sample_rate"]
        shaped = source_waveform.clone()
        filters = torchaudio.functional

        # Low band (bass)
        if low_gain_dB != 0:
            shaped = filters.bass_biquad(
                shaped,
                sample_rate,
                gain=low_gain_dB,
                central_freq=float(low_freq),
                Q=0.707,
            )

        # Mid band (peaking equalizer)
        if mid_gain_dB != 0:
            shaped = filters.equalizer_biquad(
                shaped,
                sample_rate,
                center_freq=float(mid_freq),
                gain=mid_gain_dB,
                Q=mid_q,
            )

        # High band (treble)
        if high_gain_dB != 0:
            shaped = filters.treble_biquad(
                shaped,
                sample_rate,
                gain=high_gain_dB,
                central_freq=float(high_freq),
                Q=0.707,
            )

        return IO.NodeOutput({"waveform": shaped, "sample_rate": sample_rate})
|
||||
|
||||
|
||||
class AudioExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -720,6 +781,7 @@ class AudioExtension(ComfyExtension):
|
||||
AudioMerge,
|
||||
AudioAdjustVolume,
|
||||
EmptyAudio,
|
||||
AudioEqualizer3Band,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AudioExtension:
|
||||
|
||||
@@ -1,876 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
    """Check up front that the Python OpenGL packages can be found.

    Uses ``importlib.util.find_spec`` so nothing is imported. On Linux with
    neither DISPLAY nor WAYLAND_DISPLAY set, it also looks up the EGL and
    OSMesa libraries, but only logs the outcome.

    Raises:
        RuntimeError: if the ``glfw`` or ``OpenGL`` module cannot be found.
    """
    logger.debug("_check_opengl_availability: starting")

    # (module probed with find_spec, package name recorded when it is absent)
    required = (("glfw", "glfw"), ("OpenGL", "PyOpenGL"))
    missing = []
    for module_name, package_name in required:
        logger.debug(f"_check_opengl_availability: checking for {module_name} package")
        if importlib.util.find_spec(module_name) is None:
            missing.append(package_name)

    if missing:
        raise RuntimeError(
            f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
        )

    logger.debug(f"_check_opengl_availability: platform={sys.platform}")
    on_linux = sys.platform.startswith("linux")
    if on_linux:
        has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
        logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
        if not has_display:
            # No display: see which headless backends could be used.
            logger.debug("_check_opengl_availability: checking for EGL library")
            has_egl = ctypes.util.find_library("EGL")
            logger.debug("_check_opengl_availability: checking for OSMesa library")
            has_osmesa = ctypes.util.find_library("OSMesa")

            # Raising when neither backend is found is disabled for CI, as CI
            # fails this check; the result is only logged.
            logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")

    logger.debug("_check_opengl_availability: completed")
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
|
||||
|
||||
def _import_opengl():
    """Return the ``OpenGL.GL`` module, importing it on first use.

    The module is cached in the module-level ``gl`` global, so the import
    only runs once.
    """
    global gl
    if gl is not None:
        return gl

    logger.debug("_import_opengl: importing OpenGL.GL")
    import OpenGL.GL as _gl

    gl = _gl
    logger.debug("_import_opengl: import completed")
    return gl
|
||||
|
||||
|
||||
class SizeModeInput(TypedDict):
    """Typed dictionary holding a size mode together with a width and a height.

    NOTE(review): no consumer of this type is visible in this part of the
    file; the field meanings below are inferred from the names. Confirm
    against the code that builds and reads this dict.
    """

    size_mode: str  # presumably selects how the output size is chosen (TODO confirm)
    width: int  # presumably in pixels (TODO confirm)
    height: int  # presumably in pixels (TODO confirm)
|
||||
|
||||
|
||||
MAX_IMAGES = 5 # u_image0-4
|
||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
|
||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||
# Draws a single triangle that covers the entire screen:
|
||||
#
|
||||
# (-1,3)
|
||||
# /|
|
||||
# / | <- visible area is the unit square from (-1,-1) to (1,1)
|
||||
# / | parts outside get clipped away
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
|
||||
gl_Position = vec4(verts[gl_VertexID], 0, 1);
|
||||
}
|
||||
"""
|
||||
|
||||
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
"""Detect how many fragColor outputs are used in the shader.
|
||||
|
||||
Returns the count of outputs needed (1 to MAX_OUTPUTS).
|
||||
"""
|
||||
matches = re.findall(r"fragColor(\d+)", source)
|
||||
if not matches:
|
||||
return 1 # Default to 1 output if none found
|
||||
max_index = max(int(m) for m in matches)
|
||||
return min(max_index + 1, MAX_OUTPUTS)
|
||||
|
||||
|
||||
def _init_glfw():
    """Create a hidden GLFW window with a current OpenGL 3.3 core context.

    Returns:
        ``(window, glfw_module)``.

    Raises:
        RuntimeError: on macOS (backend refused), or when ``glfw.init()`` or
            window creation fails. After a successful ``glfw.init()``, any
            failure terminates GLFW before the exception propagates.
    """
    logger.debug("_init_glfw: starting")
    # On macOS, glfw.init() must be called from main thread or it hangs forever
    if sys.platform == "darwin":
        logger.debug("_init_glfw: skipping on macOS")
        raise RuntimeError("GLFW backend not supported on macOS")

    logger.debug("_init_glfw: importing glfw module")
    import glfw as _glfw

    logger.debug("_init_glfw: calling glfw.init()")
    if not _glfw.init():
        raise RuntimeError("glfw.init() failed")

    try:
        logger.debug("_init_glfw: setting window hints")
        window_hints = (
            (_glfw.VISIBLE, _glfw.FALSE),
            (_glfw.CONTEXT_VERSION_MAJOR, 3),
            (_glfw.CONTEXT_VERSION_MINOR, 3),
            (_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE),
        )
        for hint, value in window_hints:
            _glfw.window_hint(hint, value)

        logger.debug("_init_glfw: calling create_window()")
        hidden_window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
        if not hidden_window:
            raise RuntimeError("glfw.create_window() failed")

        logger.debug("_init_glfw: calling make_context_current()")
        _glfw.make_context_current(hidden_window)
    except Exception:
        logger.debug("_init_glfw: failed, terminating glfw")
        _glfw.terminate()
        raise

    logger.debug("_init_glfw: completed successfully")
    return hidden_window, _glfw
|
||||
|
||||
|
||||
def _init_egl():
    """Initialize EGL for headless rendering.

    Creates an OpenGL 3.3 core context on the default EGL display, bound to a
    64x64 pbuffer surface, and makes it current.

    Returns:
        ``(display, context, surface, EGL_module)``.

    Raises:
        RuntimeError: if any EGL setup step reports failure. Resources that
            were actually created are released before the error propagates.
    """
    logger.debug("_init_egl: starting")
    from OpenGL import EGL as _EGL
    from OpenGL.EGL import (
        eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
        eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
        eglTerminate, eglDestroyContext, eglDestroySurface,
        EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
        EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
        EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
        EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
    )
    logger.debug("_init_egl: imports completed")

    # A handle stays None until it refers to a real EGL object. The cleanup
    # code below relies on that, so the EGL_NO_* sentinels returned on failure
    # are reset to None before raising.
    display = None
    context = None
    surface = None

    try:
        logger.debug("_init_egl: calling eglGetDisplay()")
        display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
        if display == _EGL.EGL_NO_DISPLAY:
            display = None  # Sentinel, not a display: don't terminate
            raise RuntimeError("eglGetDisplay() failed")

        logger.debug("_init_egl: calling eglInitialize()")
        major, minor = _EGL.EGLint(), _EGL.EGLint()
        if not eglInitialize(display, major, minor):
            display = None  # Not initialized, don't terminate
            raise RuntimeError("eglInitialize() failed")
        logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")

        # RGBA8 pbuffer-capable config for desktop OpenGL, no depth buffer.
        config_attribs = [
            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
            EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
            EGL_DEPTH_SIZE, 0, EGL_NONE
        ]
        configs = (_EGL.EGLConfig * 1)()
        num_configs = _EGL.EGLint()
        if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
            raise RuntimeError("eglChooseConfig() failed")
        config = configs[0]
        logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")

        if not eglBindAPI(EGL_OPENGL_API):
            raise RuntimeError("eglBindAPI() failed")

        logger.debug("_init_egl: calling eglCreateContext()")
        context_attribs = [
            _EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
            _EGL.EGL_CONTEXT_MINOR_VERSION, 3,
            _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
            EGL_NONE
        ]
        context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
        if context == EGL_NO_CONTEXT:
            context = None  # Sentinel, nothing was created: don't destroy
            raise RuntimeError("eglCreateContext() failed")

        logger.debug("_init_egl: calling eglCreatePbufferSurface()")
        pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
        surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
        if surface == _EGL.EGL_NO_SURFACE:
            surface = None  # Sentinel, nothing was created: don't destroy
            raise RuntimeError("eglCreatePbufferSurface() failed")

        logger.debug("_init_egl: calling eglMakeCurrent()")
        if not eglMakeCurrent(display, surface, surface, context):
            raise RuntimeError("eglMakeCurrent() failed")

        logger.debug("_init_egl: completed successfully")
        return display, context, surface, _EGL

    except Exception:
        logger.debug("_init_egl: failed, cleaning up")
        # Clean up any resources on failure (only handles that are real objects)
        if surface is not None:
            eglDestroySurface(display, surface)
        if context is not None:
            eglDestroyContext(display, context)
        if display is not None:
            eglTerminate(display)
        raise
|
||||
|
||||
|
||||
def _init_osmesa():
    """Create an OSMesa context, rendering into a 64x64 RGBA byte buffer, and make it current.

    Sets ``PYOPENGL_PLATFORM=osmesa`` in the process environment before
    importing the OpenGL modules.

    Returns:
        ``(context, buffer)``. The buffer is returned together with the context.

    Raises:
        RuntimeError: if the context cannot be created or made current. In
            the second case the context is destroyed first.
    """
    import ctypes

    logger.debug("_init_osmesa: starting")
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"

    logger.debug("_init_osmesa: importing OpenGL.osmesa")
    from OpenGL import GL as _gl
    from OpenGL.osmesa import (
        OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
        OSMESA_RGBA,
    )
    logger.debug("_init_osmesa: imports completed")

    osmesa_ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
    if not osmesa_ctx:
        raise RuntimeError("OSMesaCreateContextExt() failed")

    # 4 bytes per pixel (RGBA, one unsigned byte per channel).
    width = height = 64
    pixel_buffer_type = ctypes.c_ubyte * (width * height * 4)
    pixel_buffer = pixel_buffer_type()

    logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
    made_current = OSMesaMakeCurrent(osmesa_ctx, pixel_buffer, _gl.GL_UNSIGNED_BYTE, width, height)
    if not made_current:
        OSMesaDestroyContext(osmesa_ctx)
        raise RuntimeError("OSMesaMakeCurrent() failed")

    logger.debug("_init_osmesa: completed successfully")
    return osmesa_ctx, pixel_buffer
|
||||
|
||||
def _init_cgl():
    """Initialize CGL (macOS native OpenGL) and make a new context current.

    Locates the "OpenGL" library with ``ctypes.util.find_library``, loads it
    through ctypes, chooses a pixel format, creates a context from it and
    sets that context as current.

    Returns:
        ``(cgl_context, opengl_lib)``: the context handle (a
        ``ctypes.c_void_p``) and the loaded library object.

    Raises:
        RuntimeError: if the library cannot be found or any CGL call reports
            a non-zero error code.
    """
    import ctypes
    import ctypes.util

    logger.debug("_init_cgl: starting")

    opengl_path = ctypes.util.find_library("OpenGL")
    if not opengl_path:
        raise RuntimeError("Could not find OpenGL framework")
    opengl = ctypes.cdll.LoadLibrary(opengl_path)

    # Both CGL object types are handled as opaque pointers.
    CGLPixelFormatObj = ctypes.c_void_p
    CGLContextObj = ctypes.c_void_p

    # NOTE(review): numeric values presumably mirror the CGL pixel format
    # attribute constants of the same names in Apple's headers; verify.
    kCGLPFAOpenGLProfile = 99
    kCGLOGLPVersion_3_2_Core = 0x3200
    kCGLPFAAccelerated = 73
    kCGLPFAColorSize = 8
    kCGLPFAAllowOfflineRenderers = 96

    # Attribute list passed to CGLChoosePixelFormat; the trailing 0
    # presumably terminates the list (TODO confirm against the CGL reference).
    attrs = (ctypes.c_int * 7)(
        kCGLPFAOpenGLProfile, kCGLOGLPVersion_3_2_Core,
        kCGLPFAAccelerated,
        kCGLPFAColorSize, 32,
        kCGLPFAAllowOfflineRenderers,
        0,
    )

    pix_fmt = CGLPixelFormatObj()
    npix = ctypes.c_int(0)

    err = opengl.CGLChoosePixelFormat(attrs, ctypes.byref(pix_fmt), ctypes.byref(npix))
    if err != 0 or not pix_fmt:
        raise RuntimeError(f"CGLChoosePixelFormat() failed with error {err}")

    ctx = CGLContextObj()
    err = opengl.CGLCreateContext(pix_fmt, None, ctypes.byref(ctx))
    # The pixel format is released right after context creation, whether or
    # not creation succeeded.
    opengl.CGLDestroyPixelFormat(pix_fmt)
    if err != 0 or not ctx:
        raise RuntimeError(f"CGLCreateContext() failed with error {err}")

    err = opengl.CGLSetCurrentContext(ctx)
    if err != 0:
        # Don't leak the context when it cannot be made current.
        opengl.CGLDestroyContext(ctx)
        raise RuntimeError(f"CGLSetCurrentContext() failed with error {err}")

    logger.debug("_init_cgl: completed successfully")
    return ctx, opengl
|
||||
|
||||
|
||||
class GLContext:
    """Manages OpenGL context and resources for shader execution.

    Tries backends in order: GLFW (desktop) → CGL (macOS) → EGL (headless GPU) → OSMesa (software).

    The class hands out a single shared instance (see ``__new__``) and
    ``__init__`` becomes a no-op once one initialization has completed.
    """

    _instance = None  # the shared instance returned by every GLContext() call
    _initialized = False  # set only after __init__ has fully succeeded

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Create a GL context using the first backend that works.

        Raises:
            RuntimeError: if every attempted backend fails; the message lists
                each backend's error plus platform-specific install hints.
        """
        if GLContext._initialized:
            logger.debug("GLContext.__init__: already initialized, skipping")
            return

        logger.debug("GLContext.__init__: starting initialization")

        global glfw, EGL

        import time
        start = time.perf_counter()

        self._backend = None
        self._window = None
        self._cgl_ctx = None
        self._cgl_lib = None
        self._egl_display = None
        self._egl_context = None
        self._egl_surface = None
        self._osmesa_ctx = None
        self._osmesa_buffer = None
        self._vao = None

        # Try backends in order: GLFW → CGL (macOS) → EGL (non-macOS) → OSMesa
        errors = []

        logger.debug("GLContext.__init__: trying GLFW backend")
        try:
            self._window, glfw = _init_glfw()
            self._backend = "glfw"
            logger.debug("GLContext.__init__: GLFW backend succeeded")
        except Exception as e:
            logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
            errors.append(("GLFW", e))

        if self._backend is None and sys.platform == "darwin":
            logger.debug("GLContext.__init__: trying CGL backend")
            try:
                self._cgl_ctx, self._cgl_lib = _init_cgl()
                self._backend = "cgl"
                logger.debug("GLContext.__init__: CGL backend succeeded")
            except Exception as e:
                logger.debug(f"GLContext.__init__: CGL backend failed: {e}")
                errors.append(("CGL", e))

        # Skip EGL on macOS — DarwinPlatform doesn't support EGL, and importing
        # it poisons PyOpenGL's platform selection, preventing OSMesa from working.
        if self._backend is None and sys.platform != "darwin":
            logger.debug("GLContext.__init__: trying EGL backend")
            try:
                self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
                self._backend = "egl"
                logger.debug("GLContext.__init__: EGL backend succeeded")
            except Exception as e:
                logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
                errors.append(("EGL", e))

        if self._backend is None:
            logger.debug("GLContext.__init__: trying OSMesa backend")
            try:
                self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
                self._backend = "osmesa"
                logger.debug("GLContext.__init__: OSMesa backend succeeded")
            except Exception as e:
                logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
                errors.append(("OSMesa", e))

        if self._backend is None:
            if sys.platform == "win32":
                platform_help = (
                    "Windows: Ensure GPU drivers are installed and display is available.\n"
                    " CPU-only/headless mode is not supported on Windows."
                )
            elif sys.platform == "darwin":
                platform_help = (
                    "macOS: CGL context creation failed.\n"
                    " Ensure macOS OpenGL framework is available.\n"
                    " Requires: pip install PyOpenGL PyOpenGL-accelerate"
                )
            else:
                platform_help = (
                    "Linux: Install one of these backends:\n"
                    " Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
                    " Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
                    " Headless (CPU): sudo apt install libosmesa6"
                )

            error_details = "\n".join(f" {name}: {err}" for name, err in errors)
            raise RuntimeError(
                f"Failed to create OpenGL context.\n\n"
                f"Backend errors:\n{error_details}\n\n"
                f"{platform_help}"
            )

        # Now import OpenGL.GL (after context is current)
        logger.debug("GLContext.__init__: importing OpenGL.GL")
        _import_opengl()

        # Create VAO (required for core profile, but OSMesa may use compat profile)
        logger.debug("GLContext.__init__: creating VAO")
        # Bind the name up front: if glGenVertexArrays itself raises, the
        # handler below must still be able to test `vao` without hitting an
        # UnboundLocalError that would abort initialization.
        vao = None
        try:
            vao = gl.glGenVertexArrays(1)
            gl.glBindVertexArray(vao)
            self._vao = vao  # Only store after successful bind
            logger.debug("GLContext.__init__: VAO created successfully")
        except Exception as e:
            logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
            # OSMesa with older Mesa may not support VAOs
            # Clean up if we created but couldn't bind
            if vao:
                try:
                    gl.glDeleteVertexArrays(1, [vao])
                except Exception:
                    # Best-effort cleanup; running without a VAO is acceptable here.
                    pass

        elapsed = (time.perf_counter() - start) * 1000

        # Log device info
        renderer = gl.glGetString(gl.GL_RENDERER)
        vendor = gl.glGetString(gl.GL_VENDOR)
        version = gl.glGetString(gl.GL_VERSION)
        renderer = renderer.decode() if renderer else "Unknown"
        vendor = vendor.decode() if vendor else "Unknown"
        version = version.decode() if version else "Unknown"

        GLContext._initialized = True
        logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")

    def make_current(self):
        """Make this context current via the active backend and rebind the VAO, if one was created."""
        if self._backend == "glfw":
            glfw.make_context_current(self._window)
        elif self._backend == "cgl":
            self._cgl_lib.CGLSetCurrentContext(self._cgl_ctx)
        elif self._backend == "egl":
            from OpenGL.EGL import eglMakeCurrent
            eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
        elif self._backend == "osmesa":
            from OpenGL.osmesa import OSMesaMakeCurrent
            # 64x64 matches the buffer dimensions allocated in _init_osmesa.
            OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)

        if self._vao is not None:
            gl.glBindVertexArray(self._vao)
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
    """Compile ``source`` as a shader of ``shader_type`` and return its GL id.

    Raises:
        RuntimeError: if compilation fails; the message carries the GL info
            log and the shader object is deleted before raising.
    """
    shader_id = gl.glCreateShader(shader_type)
    gl.glShaderSource(shader_id, source)
    gl.glCompileShader(shader_id)

    if gl.glGetShaderiv(shader_id, gl.GL_COMPILE_STATUS) == gl.GL_TRUE:
        return shader_id

    info_log = gl.glGetShaderInfoLog(shader_id).decode()
    gl.glDeleteShader(shader_id)
    raise RuntimeError(f"Shader compilation failed:\n{info_log}")
|
||||
|
||||
|
||||
def _create_program(vertex_source: str, fragment_source: str) -> int:
    """Compile both shader stages, link them, and return the program id.

    Raises:
        RuntimeError: if either stage fails to compile or the program fails
            to link (the message carries the GL info log).
    """
    vs_id = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
    try:
        fs_id = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
    except RuntimeError:
        # The vertex stage compiled fine; release it before propagating.
        gl.glDeleteShader(vs_id)
        raise

    program_id = gl.glCreateProgram()
    for stage_id in (vs_id, fs_id):
        gl.glAttachShader(program_id, stage_id)
    gl.glLinkProgram(program_id)

    # The shader objects are deleted right after the link attempt.
    for stage_id in (vs_id, fs_id):
        gl.glDeleteShader(stage_id)

    if gl.glGetProgramiv(program_id, gl.GL_LINK_STATUS) == gl.GL_TRUE:
        return program_id

    link_log = gl.glGetProgramInfoLog(program_id).decode()
    gl.glDeleteProgram(program_id)
    raise RuntimeError(f"Program linking failed:\n{link_log}")
|
||||
|
||||
|
||||
def _render_shader_batch(
    fragment_code: str,
    width: int,
    height: int,
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
) -> list[list[np.ndarray]]:
    """
    Run one fragment shader over every batch, sharing GL objects between batches.

    The program, framebuffer and textures are created once; per batch only the
    input texture contents are re-uploaded before drawing and reading back.

    Args:
        fragment_code: User's fragment shader code
        width: Output width
        height: Output height
        image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
        floats: List of float uniforms (bound to ``u_float{i}``)
        ints: List of int uniforms (bound to ``u_int{i}``)

    Returns:
        List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]

    Raises:
        RuntimeError: if the shader program cannot be built or the framebuffer
            is incomplete.
    """
    if not image_batches:
        return []

    gl_context = GLContext()
    gl_context.make_current()

    # Convert from GLSL ES to desktop GLSL 330
    fragment_source = _convert_es_to_desktop(fragment_code)

    # Detect how many outputs the shader actually uses
    num_outputs = _detect_output_count(fragment_code)

    # Everything created below is tracked so the finally-clause can free it
    program = None
    fbo = None
    output_textures = []
    input_textures = []

    num_inputs = len(image_batches[0])

    try:
        # Build the program once for all batches
        try:
            program = _create_program(VERTEX_SHADER, fragment_source)
        except RuntimeError:
            logger.error(f"Fragment shader:\n{fragment_source}")
            raise

        gl.glUseProgram(program)

        # Framebuffer with one float colour attachment per detected output
        fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)

        attachments = [gl.GL_COLOR_ATTACHMENT0 + slot for slot in range(num_outputs)]
        for attachment in attachments:
            out_tex = gl.glGenTextures(1)
            output_textures.append(out_tex)
            gl.glBindTexture(gl.GL_TEXTURE_2D, out_tex)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, attachment, gl.GL_TEXTURE_2D, out_tex, 0)

        gl.glDrawBuffers(num_outputs, attachments)

        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")

        # One input texture per texture unit, reused for all batches
        for unit in range(num_inputs):
            in_tex = gl.glGenTextures(1)
            input_textures.append(in_tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
            gl.glBindTexture(gl.GL_TEXTURE_2D, in_tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            sampler_loc = gl.glGetUniformLocation(program, f"u_image{unit}")
            if sampler_loc >= 0:
                gl.glUniform1i(sampler_loc, unit)

        # Uniforms that stay constant across batches
        resolution_loc = gl.glGetUniformLocation(program, "u_resolution")
        if resolution_loc >= 0:
            gl.glUniform2f(resolution_loc, float(width), float(height))

        for prefix, setter, values in (
            ("u_float", gl.glUniform1f, floats),
            ("u_int", gl.glUniform1i, ints),
        ):
            for idx, value in enumerate(values):
                uniform_loc = gl.glGetUniformLocation(program, f"{prefix}{idx}")
                if uniform_loc >= 0:
                    setter(uniform_loc, value)

        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly

        all_batch_outputs = []
        for batch_images in image_batches:
            # Upload this batch's images into the shared input textures
            for unit, src in enumerate(batch_images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[unit])

                # Flip vertically for GL coordinates, ensure RGBA
                rows, cols, channels = src.shape
                flipped = src[::-1, :, :]
                if channels == 3:
                    upload = np.empty((rows, cols, 4), dtype=np.float32)
                    upload[:, :, :3] = flipped
                    upload[:, :, 3] = 1.0  # opaque alpha for RGB inputs
                else:
                    upload = np.ascontiguousarray(flipped)

                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, cols, rows, 0, gl.GL_RGBA, gl.GL_FLOAT, upload)

            # Render
            gl.glClearColor(0, 0, 0, 0)
            gl.glClear(gl.GL_COLOR_BUFFER_BIT)
            gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)

            # Read back outputs for this batch
            # (glGetTexImage is synchronous, implicitly waits for rendering)
            batch_outputs = []
            for out_tex in output_textures:
                gl.glBindTexture(gl.GL_TEXTURE_2D, out_tex)
                raw = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
                pixels = np.frombuffer(raw, dtype=np.float32).reshape(height, width, 4)
                # Undo the vertical flip applied on upload
                batch_outputs.append(np.ascontiguousarray(pixels[::-1, :, :]))

            # Pad with black images for unused outputs (one shared array per batch)
            black_img = np.zeros((height, width, 4), dtype=np.float32)
            batch_outputs.extend([black_img] * (MAX_OUTPUTS - num_outputs))

            all_batch_outputs.append(batch_outputs)

        return all_batch_outputs

    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)

        if input_textures:
            gl.glDeleteTextures(len(input_textures), input_textures)
        if output_textures:
            gl.glDeleteTextures(len(output_textures), output_textures)
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        if program is not None:
            gl.glDeleteProgram(program)
|
||||
|
||||
class GLSLShader(io.ComfyNode):
    """Node that runs a user-supplied GLSL fragment shader over input images."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: shader source, size mode, and auto-growing image/float/int inputs."""
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )
        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )
        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        # Extra widgets shown only when the "custom" size mode is selected.
        custom_size_inputs = [
            io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
            io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
        ]
        size_mode_input = io.DynamicCombo.Input(
            "size_mode",
            options=[
                io.DynamicCombo.Option("from_input", []),
                io.DynamicCombo.Option("custom", custom_size_inputs),
            ],
            tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
        )
        shader_input = io.String.Input(
            "fragment_shader",
            default=DEFAULT_FRAGMENT_SHADER,
            multiline=True,
            tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
        )

        last_image = MAX_IMAGES - 1
        last_uniform = MAX_UNIFORMS - 1
        last_output = MAX_OUTPUTS - 1
        description = (
            "Apply GLSL fragment shaders to images. "
            f"Inputs: u_image0-{last_image} (sampler2D), u_resolution (vec2), "
            f"u_float0-{last_uniform}, u_int0-{last_uniform}. "
            f"Outputs: layout(location = 0-{last_output}) out vec4 fragColor0-{last_output}."
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=description,
            inputs=[
                shader_input,
                size_mode_input,
                io.Autogrow.Input("images", template=image_template),
                io.Autogrow.Input("floats", template=float_template),
                io.Autogrow.Input("ints", template=int_template),
            ],
            outputs=[io.Image.Output(display_name=f"IMAGE{slot}") for slot in range(4)],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        """Render the shader once per batch item and return the output image tensors.

        Raises:
            ValueError: if no input image is connected.
        """
        image_list = [tensor for tensor in images.values() if tensor is not None]
        # Unconnected uniforms fall back to zero.
        float_list = [0.0 if v is None else v for v in floats.values()] if floats else []
        int_list = [0 if v is None else v for v in ints.values()] if ints else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Output size: explicit in "custom" mode, otherwise taken from the first image
        first_image = image_list[0]
        if size_mode["size_mode"] == "custom":
            out_width = size_mode["width"]
            out_height = size_mode["height"]
        else:
            out_height, out_width = first_image.shape[1:3]

        batch_size = first_image.shape[0]

        # One list of numpy images per batch index
        image_batches = [
            [tensor[batch_idx].cpu().numpy().astype(np.float32) for tensor in image_list]
            for batch_idx in range(batch_size)
        ]

        all_batch_outputs = _render_shader_batch(
            fragment_shader,
            out_width,
            out_height,
            image_batches,
            float_list,
            int_list,
        )

        # Regroup per-batch results into one list per output slot, then stack
        per_slot = [[] for _ in range(MAX_OUTPUTS)]
        for batch_outputs in all_batch_outputs:
            for slot, rendered in enumerate(batch_outputs):
                per_slot[slot].append(torch.from_numpy(rendered))

        output_tensors = [torch.stack(per_slot[slot], dim=0) for slot in range(MAX_OUTPUTS)]
        ui_payload = cls._build_ui_output(image_list, output_tensors[0])
        return io.NodeOutput(*output_tensors, ui=ui_payload)

    @classmethod
    def _build_ui_output(
        cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
    ) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution."""
        save_options = {
            "folder_type": io.FolderType.temp,
            "cls": None,
            "compress_level": 1,
        }

        input_images_ui = ui.ImageSaveHelper.save_images(
            torch.cat(image_list, dim=0),
            filename_prefix="GLSLShader_input",
            **save_options,
        )
        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            **save_options,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}
|
||||
|
||||
|
||||
class GLSLExtension(ComfyExtension):
    """Extension that contributes the GLSL shader node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [GLSLShader]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> GLSLExtension:
    """Entry point: build and return the GLSL extension instance."""
    extension = GLSLExtension()
    return extension
|
||||
@@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
display_name="Crop Image (Deprecated)",
|
||||
category="image/transform",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
@@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode):
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageCropV2(IO.ComfyNode):
    """Crop an image batch to a bounding-box region."""

    @classmethod
    def define_schema(cls):
        """Describe the node: an image plus a ``crop_region`` bounding box."""
        node_inputs = [
            IO.Image.Input("image"),
            IO.BoundingBox.Input("crop_region", component="ImageCrop"),
        ]
        return IO.Schema(
            node_id="ImageCropV2",
            search_aliases=["trim"],
            display_name="Image Crop",
            category="image/transform",
            inputs=node_inputs,
            outputs=[IO.Image.Output()],
        )

    @classmethod
    def execute(cls, image, crop_region) -> IO.NodeOutput:
        """Slice ``image`` to the region; the origin is clamped to the last valid pixel."""
        # Indexing below treats image as (batch, height, width, channels).
        left = min(crop_region.get("x", 0), image.shape[2] - 1)
        top = min(crop_region.get("y", 0), image.shape[1] - 1)
        right = crop_region.get("width", 512) + left
        bottom = crop_region.get("height", 512) + top

        cropped = image[:, top:bottom, left:right, :]
        return IO.NodeOutput(cropped, ui=UI.PreviewImage(cropped))
|
||||
|
||||
|
||||
class BoundingBox(IO.ComfyNode):
    """Primitive node that packs x/y/width/height into a bounding-box dict."""

    @classmethod
    def define_schema(cls):
        """Describe the node: four integer inputs and one bounding-box output."""
        origin_inputs = [
            IO.Int.Input(axis, default=0, min=0, max=MAX_RESOLUTION)
            for axis in ("x", "y")
        ]
        extent_inputs = [
            IO.Int.Input(dimension, default=512, min=1, max=MAX_RESOLUTION)
            for dimension in ("width", "height")
        ]
        return IO.Schema(
            node_id="PrimitiveBoundingBox",
            display_name="Bounding Box",
            category="utils/primitive",
            inputs=origin_inputs + extent_inputs,
            outputs=[IO.BoundingBox.Output()],
        )

    @classmethod
    def execute(cls, x, y, width, height) -> IO.NodeOutput:
        """Return the four values as a dict keyed ``x``, ``y``, ``width``, ``height``."""
        region = dict(x=x, y=y, width=width, height=height)
        return IO.NodeOutput(region)
|
||||
|
||||
|
||||
class RepeatImageBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -535,6 +587,7 @@ class ImageRotate(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageRotate",
|
||||
display_name="Rotate",
|
||||
search_aliases=["turn", "flip orientation"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
@@ -632,6 +685,8 @@ class ImagesExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCrop,
|
||||
ImageCropV2,
|
||||
BoundingBox,
|
||||
RepeatImageBatch,
|
||||
ImageFromBatch,
|
||||
ImageAddNoise,
|
||||
|
||||
@@ -29,7 +29,7 @@ class Load3D(IO.ComfyNode):
|
||||
]
|
||||
return IO.Schema(
|
||||
node_id="Load3D",
|
||||
display_name="Load 3D & Animation",
|
||||
display_name="Load 3D model",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class NAGuidance(io.ComfyNode):
    """Node that patches a model's attention output with Normalized Attention Guidance."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: a model plus the three NAG parameters."""
        # NOTE: start_percent / end_percent inputs are not exposed; sigma-range
        # gating is currently disabled in this node.
        node_inputs = [
            io.Model.Input("model", tooltip="The model to apply NAG to."),
            io.Float.Input("nag_scale", default=5.0, min=0.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
            io.Float.Input("nag_alpha", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
            io.Float.Input("nag_tau", default=1.5, min=1.0, max=10.0, step=0.01),
        ]
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=node_inputs,
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        """Clone ``model``, install the NAG attention-output patch, and return the clone."""
        patched = model.clone()

        def nag_attention_output_patch(out, extra_options):
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            # Guidance needs both the positive (0) and negative (1) entries present.
            if cond_or_uncond is None:
                return out
            if 1 not in cond_or_uncond or 0 not in cond_or_uncond:
                return out

            img_slice = extra_options.get("img_slice", None)
            has_img_slice = img_slice is not None

            if has_img_slice:
                orig_out = out
                out = out[:, img_slice[0]:img_slice[1]]  # only apply on img part

            # Rows are split evenly between the cond_or_uncond entries.
            chunk = out.shape[0] // len(cond_or_uncond)
            neg_index = cond_or_uncond.index(1)
            pos_index = cond_or_uncond.index(0)
            pos_rows = slice(chunk * pos_index, chunk * (pos_index + 1))
            neg_rows = slice(chunk * neg_index, chunk * (neg_index + 1))

            z_pos = out[pos_rows]
            z_neg = out[neg_rows]

            # Extrapolate away from the negative branch.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            # Cap the L1-norm growth of the guided features at nag_tau times the positive norm.
            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
            ratio = norm_guided / norm_pos
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
            guided_normalized = guided * scale_factor

            # Blend the normalized guidance with the original positive output.
            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if not has_img_slice:
                out[pos_rows] = z_final
                return out

            img_cols = slice(img_slice[0], img_slice[1])
            orig_out[neg_rows, img_cols] = z_final
            orig_out[pos_rows, img_cols] = z_final
            return orig_out

        patched.set_model_attn1_output_patch(nag_attention_output_patch)
        patched.disable_model_cfg1_optimization()

        return io.NodeOutput(patched)
|
||||
|
||||
|
||||
class NagExtension(ComfyExtension):
    """Extension that contributes the Normalized Attention Guidance node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [NAGuidance]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NagExtension:
    """Entry point: build and return the NAG extension instance."""
    extension = NagExtension()
    return extension
|
||||
176
comfy_extras/nodes_textgen.py
Normal file
176
comfy_extras/nodes_textgen.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
class TextGenerate(io.ComfyNode):
    """Node that generates text from a prompt (and optional image) via the supplied clip object."""

    @classmethod
    def define_schema(cls):
        """Describe the node; sampling parameters appear only when sampling mode is "on"."""
        sampling_on_inputs = [
            io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
            io.Int.Input("top_k", default=64, min=0, max=1000),
            io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
            io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
            io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
            io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
        ]
        sampling_options = [
            io.DynamicCombo.Option(key="on", inputs=sampling_on_inputs),
            io.DynamicCombo.Option(key="off", inputs=[]),
        ]

        return io.Schema(
            node_id="TextGenerate",
            category="textgen/",
            search_aliases=["LLM", "gemma"],
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
                io.Image.Input("image", optional=True),
                io.Int.Input("max_length", default=256, min=1, max=2048),
                io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
            ],
            outputs=[
                io.String.Output(display_name="generated_text"),
            ],
        )

    @classmethod
    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
        """Tokenize the prompt, generate token ids, and return the decoded text."""
        tokens = clip.tokenize(prompt, image=image, skip_template=False)

        # Keys missing from the combo (e.g. in "off" mode) fall back to these defaults.
        generation_kwargs = {
            "do_sample": sampling_mode.get("sampling_mode") == "on",
            "max_length": max_length,
            "temperature": sampling_mode.get("temperature", 1.0),
            "top_k": sampling_mode.get("top_k", 50),
            "top_p": sampling_mode.get("top_p", 1.0),
            "min_p": sampling_mode.get("min_p", 0.0),
            "repetition_penalty": sampling_mode.get("repetition_penalty", 1.0),
            "seed": sampling_mode.get("seed", None),
        }
        generated_ids = clip.generate(tokens, **generation_kwargs)

        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
        return io.NodeOutput(generated_text)
|
||||
|
||||
|
||||
LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
#### Guidelines
|
||||
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
|
||||
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
|
||||
- Speech (only when requested):
|
||||
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||
- Specify language if not English and accent if relevant.
|
||||
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
|
||||
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||
|
||||
#### Important notes:
|
||||
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
|
||||
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
|
||||
- Format: DO NOT start your response with special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single continuous paragraph in natural language (English).
|
||||
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
|
||||
|
||||
#### Example
|
||||
Input: "A woman at a coffee shop talking on the phone"
|
||||
Output:
|
||||
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
|
||||
"""
|
||||
|
||||
# System prompt used by TextGenerateLTX2Prompt when an image is supplied.
# It is inserted (after .strip()) into the system turn of the chat prompt, so
# every character below is runtime text sent to the model.
#
# The text previously opened with a verbatim copy of the text-to-video intro
# sentence followed by the image-to-video intro, giving the model two
# conflicting role statements; only the image-to-video intro is kept.
LTX2_I2V_SYSTEM_PROMPT = """You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.

#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.

#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.

#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.

#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
|
||||
|
||||
class TextGenerateLTX2Prompt(TextGenerate):
    """TextGenerate variant that wraps the user prompt in an LTX2 system prompt.

    The image-to-video system prompt is used when an image is given, the
    text-to-video one otherwise; generation itself is done by the parent.
    """

    @classmethod
    def define_schema(cls):
        """Reuse the parent's category, inputs and outputs under a new node id."""
        base = super().define_schema()
        return io.Schema(
            node_id="TextGenerateLTX2Prompt",
            category=base.category,
            inputs=base.inputs,
            outputs=base.outputs,
            search_aliases=["prompt enhance", "LLM", "gemma"],
        )

    @classmethod
    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
        """Build the chat-formatted prompt and delegate to ``TextGenerate.execute``.

        The user's text is placed in the user turn after the literal prefix
        ``User Raw Input Prompt: `` and followed by a period; with an image,
        an ``<image_soft_token>`` placeholder precedes it.
        """
        if image is not None:
            chat_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        else:
            chat_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        return super().execute(clip, chat_prompt, max_length, sampling_mode, image)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
    """Extension whose node list holds the two text-generation nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        provided_nodes = [TextGenerate, TextGenerateLTX2Prompt]
        return provided_nodes
|
||||
|
||||
async def comfy_entrypoint() -> TextgenExtension:
    """Create and return a fresh TextgenExtension instance."""
    extension = TextgenExtension()
    return extension
|
||||
@@ -144,7 +144,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="GetVideoComponents",
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
display_name="Extract frame",
|
||||
category="image/video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.13.0"
|
||||
__version__ = "0.14.1"
|
||||
|
||||
17
nodes.py
17
nodes.py
@@ -2105,7 +2105,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"VAELoader": "Load VAE",
|
||||
"LoraLoader": "Load LoRA (Model and CLIP)",
|
||||
"LoraLoader": "Load style (LoRA)",
|
||||
"LoraLoaderModelOnly": "Load LoRA",
|
||||
"CLIPLoader": "Load CLIP",
|
||||
"ControlNetLoader": "Load ControlNet Model",
|
||||
@@ -2116,7 +2116,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Conditioning
|
||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||
"StyleModelApply": "Apply Style Model",
|
||||
"CLIPTextEncode": "CLIP Text Encode (Prompt)",
|
||||
"CLIPTextEncode": "Text",
|
||||
"CLIPSetLastLayer": "CLIP Set Last Layer",
|
||||
"ConditioningCombine": "Conditioning (Combine)",
|
||||
"ConditioningAverage ": "Conditioning (Average)",
|
||||
@@ -2147,15 +2147,15 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoadImage": "Load Image",
|
||||
"LoadImageMask": "Load Image (as Mask)",
|
||||
"LoadImageOutput": "Load Image (from Outputs)",
|
||||
"ImageScale": "Upscale Image",
|
||||
"ImageScale": "Resize Image",
|
||||
"ImageScaleBy": "Upscale Image By",
|
||||
"ImageInvert": "Invert Image",
|
||||
"ImageInvert": "Invert",
|
||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||
"ImageBatch": "Batch Images",
|
||||
"ImageCrop": "Image Crop",
|
||||
"ImageBatch": "Batch Image",
|
||||
"ImageCrop": "Crop Image",
|
||||
"ImageStitch": "Image Stitch",
|
||||
"ImageBlend": "Image Blend",
|
||||
"ImageBlur": "Image Blur",
|
||||
"ImageBlur": "Blur",
|
||||
"ImageQuantize": "Image Quantize",
|
||||
"ImageSharpen": "Image Sharpen",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
@@ -2433,11 +2433,12 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_wanmove.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_textgen.py",
|
||||
"nodes_color.py",
|
||||
"nodes_toolkit.py",
|
||||
"nodes_replacements.py",
|
||||
"nodes_nag.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.13.0"
|
||||
version = "0.14.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.38.14
|
||||
comfyui-workflow-templates==0.8.42
|
||||
comfyui-frontend-package==1.39.14
|
||||
comfyui-workflow-templates==0.8.43
|
||||
comfyui-embedded-docs==0.4.1
|
||||
torch
|
||||
torchsde
|
||||
@@ -30,6 +30,3 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
PyOpenGL-accelerate
|
||||
glfw
|
||||
|
||||
Reference in New Issue
Block a user