Compare commits

..

14 Commits

Author SHA1 Message Date
Jedrzej Kosinski
631d4c1921 Add all_model_folders as possible entry to extra_model_paths so that all model-related paths will get mapped without needing to specify them individually 2026-02-17 22:52:36 -08:00
Terry Jia
8ad38d2073 BBox widget (#11594)
* Boundingbox widget

* code improve

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 17:13:39 -08:00
Comfy Org PR Bot
6c14f129af Bump comfyui-frontend-package to 1.39.14 (#12494)
* Bump comfyui-frontend-package to 1.39.13

* Update requirements.txt

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-02-17 13:41:34 -08:00
rattus
58dcc97dcf ops: limit return of requants (#12506)
This check was far too broad and the dtype is not a reliable indicator
of wanting the requant (as QT returns the compute dtype as the dtype).
So explictly plumb whether fp8mm wants the requant or not.
2026-02-17 15:32:27 -05:00
comfyanonymous
19236edfa4 ComfyUI v0.14.1 2026-02-17 13:28:06 -05:00
ComfyUI Wiki
73c3f86973 chore: update workflow templates to v0.8.43 (#12507) 2026-02-17 13:25:55 -05:00
Alexander Piskun
262abf437b feat(api-nodes): add Recraft V4 nodes (#12502) 2026-02-17 13:25:44 -05:00
Alexander Piskun
5284e6bf69 feat(api-nodes): add "viduq3-turbo" model and Vidu3StartEnd node; fix the price badges (#12482) 2026-02-17 10:07:14 -08:00
chaObserv
44f8598521 Fix anima LLM adapter forward when manual cast (#12504) 2026-02-17 07:56:44 -08:00
comfyanonymous
fe52843fe5 ComfyUI v0.14.0 2026-02-17 00:39:54 -05:00
comfyanonymous
c39653163d Fix anima preprocess text embeds not using right inference dtype. (#12501) 2026-02-17 00:29:20 -05:00
comfyanonymous
18927538a1 Implement NAG on all the models based on the Flux code. (#12500)
Use the Normalized Attention Guidance node.

Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc..
2026-02-16 23:30:34 -05:00
Jedrzej Kosinski
8a6fbc2dc2 Allow control_after_generate to be type ControlAfterGenerate in v3 schema (#12187) 2026-02-16 22:20:21 -05:00
Alex Butler
b44fc4c589 add venv* to gitignore (#12431) 2026-02-16 22:16:19 -05:00
23 changed files with 922 additions and 936 deletions

2
.gitignore vendored
View File

@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
venv/
venv*/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example

View File

@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)

View File

@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:

View File

@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@@ -142,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -231,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs):
return None
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1165,7 +1167,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

View File

@@ -406,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

View File

@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None
xfer_dest = None
@@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
@@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
# Reshape output back to 3D if input was 3D
if reshaped_3d:

View File

@@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
slider = "slider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -1203,6 +1209,30 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -2097,6 +2127,7 @@ __all__ = [
"UploadType",
"RemoteOptions",
"NumberDisplay",
"ControlAfterGenerate",
"comfytype",
"Custom",
@@ -2183,5 +2214,6 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"NodeReplace",
]

View File

@@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
}
class RecraftModel(str, Enum):
recraftv3 = 'recraftv3'
recraftv2 = 'recraftv2'
class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
@@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024'
RECRAFT_V4_SIZES = [
"1024x1024",
"1536x768",
"768x1536",
"1280x832",
"832x1280",
"1216x896",
"896x1216",
"1152x896",
"896x1152",
"832x1344",
"1280x896",
"896x1280",
"1344x768",
"768x1344",
]
RECRAFT_V4_PRO_SIZES = [
"2048x2048",
"3072x1536",
"1536x3072",
"2560x1664",
"1664x2560",
"2432x1792",
"1792x2432",
"2304x1792",
"1792x2304",
"1664x2688",
"1434x1024",
"1024x1434",
"2560x1792",
"1792x2560",
]
class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
@@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation")
# text_layout
class RecraftReturnedObject(BaseModel):

View File

@@ -1,5 +1,4 @@
from io import BytesIO
from typing import Optional, Union
import aiohttp
import torch
@@ -9,6 +8,8 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft import (
RECRAFT_V4_PRO_SIZES,
RECRAFT_V4_SIZES,
RecraftColor,
RecraftColorChain,
RecraftControls,
@@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
RecraftImageGenerationResponse,
RecraftImageSize,
RecraftIO,
RecraftModel,
RecraftStyle,
RecraftStyleV3,
get_v3_substyles,
@@ -39,7 +39,7 @@ async def handle_recraft_file_request(
cls: type[IO.ComfyNode],
image: torch.Tensor,
path: str,
mask: Optional[torch.Tensor] = None,
mask: torch.Tensor | None = None,
total_pixels: int = 4096 * 4096,
timeout: int = 1024,
request=None,
@@ -73,11 +73,11 @@ async def handle_recraft_file_request(
def recraft_multipart_parser(
data,
parent_key=None,
formatter: Optional[type[callable]] = None,
converted_to_check: Optional[list[list]] = None,
formatter: type[callable] | None = None,
converted_to_check: list[list] | None = None,
is_list: bool = False,
return_mode: str = "formdata", # "dict" | "formdata"
) -> Union[dict, aiohttp.FormData]:
) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
@@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft",
description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
],
@@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
model=RecraftModel.recraftv3,
model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
model=RecraftModel.recraftv3,
model="recraftv3",
n=n,
strength=round(strength, 2),
style=recraft_style.style,
@@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
model=RecraftModel.recraftv3,
model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
model=RecraftModel.recraftv3,
model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
model=RecraftModel.recraftv3,
model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
)
class RecraftV4TextToImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image",
category="api node/image/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"recraftv4",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_SIZES,
default="1024x1024",
tooltip="The size of the generated image.",
),
],
),
IO.DynamicCombo.Option(
"recraftv4_pro",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_PRO_SIZES,
default="2048x2048",
tooltip="The size of the generated image.",
),
],
),
],
tooltip="The model to use for generation.",
),
IO.Int.Input(
"n",
default=1,
min=1,
max=6,
tooltip="The number of images to generate.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Custom(RecraftIO.CONTROLS).Input(
"recraft_controls",
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
expr="""
(
$prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
model: dict,
n: int,
seed: int,
recraft_controls: RecraftControls | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
response_model=RecraftImageGenerationResponse,
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model=model["model"],
size=model["size"],
n=n,
controls=recraft_controls.create_api_model() if recraft_controls else None,
),
max_retries=1,
)
images = []
for data in response.data:
with handle_recraft_image_output():
image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
if len(image.shape) < 4:
image = image.unsqueeze(0)
images.append(image)
return IO.NodeOutput(torch.cat(images, dim=0))
class RecraftV4TextToVectorNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector",
category="api node/image/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"recraftv4",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_SIZES,
default="1024x1024",
tooltip="The size of the generated image.",
),
],
),
IO.DynamicCombo.Option(
"recraftv4_pro",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_PRO_SIZES,
default="2048x2048",
tooltip="The size of the generated image.",
),
],
),
],
tooltip="The model to use for generation.",
),
IO.Int.Input(
"n",
default=1,
min=1,
max=6,
tooltip="The number of images to generate.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Custom(RecraftIO.CONTROLS).Input(
"recraft_controls",
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
optional=True,
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
expr="""
(
$prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
model: dict,
n: int,
seed: int,
recraft_controls: RecraftControls | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
response_model=RecraftImageGenerationResponse,
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model=model["model"],
size=model["size"],
n=n,
style="vector_illustration",
substyle=None,
controls=recraft_controls.create_api_model() if recraft_controls else None,
),
max_retries=1,
)
svg_data = []
for data in response.data:
svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
return IO.NodeOutput(SVG(svg_data))
class RecraftExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension):
RecraftCreateStyleNode,
RecraftColorRGBNode,
RecraftControlsNode,
RecraftV4TextToImageNode,
RecraftV4TextToVectorNode,
]

View File

@@ -54,6 +54,7 @@ async def execute_task(
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: r.progress,
price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
max_poll_attempts=max_poll_attempts,
)
if not response.creations:
@@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
),
],
),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
tooltip="The aspect ratio of the output video.",
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
],
tooltip="Model to use for video generation.",
),
@@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
$d := $lookup(widgets, "model.duration");
$contains(widgets.model, "turbo")
? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
{"type":"usd","usd": $rate * $d}
)
)
""",
),
@@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
),
],
),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
],
tooltip="Model to use for video generation.",
),
@@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
$d := $lookup(widgets, "model.duration");
$contains(widgets.model, "turbo")
? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
{"type":"usd","usd": $rate * $d}
)
)
""",
),
@@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu3StartEndToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"viduq3-pro",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
],
tooltip="Model to use for video generation.",
),
IO.Image.Input("first_frame"),
IO.Image.Input("end_frame"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt description (max 2000 characters).",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$d := $lookup(widgets, "model.duration");
$contains(widgets.model, "turbo")
? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
{"type":"usd","usd": $rate * $d}
)
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
first_frame: Input.Image,
end_frame: Input.Image,
prompt: str,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2000)
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model=model["model"],
prompt=prompt,
duration=model["duration"],
seed=seed,
resolution=model["resolution"],
audio=model["audio"],
images=[
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
],
)
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension):
ViduMultiFrameVideoNode,
Vidu3TextToVideoNode,
Vidu3ImageToVideoNode,
Vidu3StartEndToVideoNode,
]

View File

@@ -1,876 +0,0 @@
import os
import sys
import re
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
height: int
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
# Draws a single triangle that covers the entire screen:
#
# (-1,3)
# /|
# / | <- visible area is the unit square from (-1,-1) to (1,1)
# / | parts outside get clipped away
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
fragColor0 = texture(u_image0, v_texCoord);
}
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _detect_output_count(source: str) -> int:
"""Detect how many fragColor outputs are used in the shader.
Returns the count of outputs needed (1 to MAX_OUTPUTS).
"""
matches = re.findall(r"fragColor(\d+)", source)
if not matches:
return 1 # Default to 1 output if none found
max_index = max(int(m) for m in matches)
return min(max_index + 1, MAX_OUTPUTS)
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
def _init_cgl():
"""Initialize CGL (macOS native OpenGL). Returns (cgl_context, opengl_lib). Raises RuntimeError on failure."""
import ctypes
import ctypes.util
logger.debug("_init_cgl: starting")
opengl_path = ctypes.util.find_library("OpenGL")
if not opengl_path:
raise RuntimeError("Could not find OpenGL framework")
opengl = ctypes.cdll.LoadLibrary(opengl_path)
CGLPixelFormatObj = ctypes.c_void_p
CGLContextObj = ctypes.c_void_p
kCGLPFAOpenGLProfile = 99
kCGLOGLPVersion_3_2_Core = 0x3200
kCGLPFAAccelerated = 73
kCGLPFAColorSize = 8
kCGLPFAAllowOfflineRenderers = 96
attrs = (ctypes.c_int * 7)(
kCGLPFAOpenGLProfile, kCGLOGLPVersion_3_2_Core,
kCGLPFAAccelerated,
kCGLPFAColorSize, 32,
kCGLPFAAllowOfflineRenderers,
0,
)
pix_fmt = CGLPixelFormatObj()
npix = ctypes.c_int(0)
err = opengl.CGLChoosePixelFormat(attrs, ctypes.byref(pix_fmt), ctypes.byref(npix))
if err != 0 or not pix_fmt:
raise RuntimeError(f"CGLChoosePixelFormat() failed with error {err}")
ctx = CGLContextObj()
err = opengl.CGLCreateContext(pix_fmt, None, ctypes.byref(ctx))
opengl.CGLDestroyPixelFormat(pix_fmt)
if err != 0 or not ctx:
raise RuntimeError(f"CGLCreateContext() failed with error {err}")
err = opengl.CGLSetCurrentContext(ctx)
if err != 0:
opengl.CGLDestroyContext(ctx)
raise RuntimeError(f"CGLSetCurrentContext() failed with error {err}")
logger.debug("_init_cgl: completed successfully")
return ctx, opengl
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → CGL (macOS) → EGL (headless GPU) → OSMesa (software).
"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._cgl_ctx = None
self._cgl_lib = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._vao = None
# Try backends in order: GLFW → CGL (macOS) → EGL (non-macOS) → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
if self._backend is None and sys.platform == "darwin":
logger.debug("GLContext.__init__: trying CGL backend")
try:
self._cgl_ctx, self._cgl_lib = _init_cgl()
self._backend = "cgl"
logger.debug("GLContext.__init__: CGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: CGL backend failed: {e}")
errors.append(("CGL", e))
# Skip EGL on macOS — DarwinPlatform doesn't support EGL, and importing
# it poisons PyOpenGL's platform selection, preventing OSMesa from working.
if self._backend is None and sys.platform != "darwin":
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: CGL context creation failed.\n"
" Ensure macOS OpenGL framework is available.\n"
" Requires: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
)
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "cgl":
self._cgl_lib.CGLSetCurrentContext(self._cgl_ctx)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
shader = gl.glCreateShader(shader_type)
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
return shader
def _create_program(vertex_source: str, fragment_source: str) -> int:
"""Create and link a shader program."""
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
try:
fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
except RuntimeError:
gl.glDeleteShader(vertex_shader)
raise
program = gl.glCreateProgram()
gl.glAttachShader(program, vertex_shader)
gl.glAttachShader(program, fragment_shader)
gl.glLinkProgram(program)
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
return program
def _render_shader_batch(
fragment_code: str,
width: int,
height: int,
image_batches: list[list[np.ndarray]],
floats: list[float],
ints: list[int],
) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
Compiles shader once, reuses framebuffer/textures across batches.
Args:
fragment_code: User's fragment shader code
width: Output width
height: Output height
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
"""
if not image_batches:
return []
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
# Track resources for cleanup
program = None
fbo = None
output_textures = []
input_textures = []
num_inputs = len(image_batches[0])
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
raise
gl.glUseProgram(program)
# Create framebuffer with only the needed color attachments
fbo = gl.glGenFramebuffers(1)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
draw_buffers = []
for i in range(num_outputs):
tex = gl.glGenTextures(1)
output_textures.append(tex)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
gl.glDrawBuffers(num_outputs, draw_buffers)
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
raise RuntimeError("Framebuffer is not complete")
# Create input textures (reused for all batches)
for i in range(num_inputs):
tex = gl.glGenTextures(1)
input_textures.append(tex)
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_image{i}")
if loc >= 0:
gl.glUniform1i(loc, i)
# Set static uniforms (once for all batches)
loc = gl.glGetUniformLocation(program, "u_resolution")
if loc >= 0:
gl.glUniform2f(loc, float(width), float(height))
for i, v in enumerate(floats):
loc = gl.glGetUniformLocation(program, f"u_float{i}")
if loc >= 0:
gl.glUniform1f(loc, v)
for i, v in enumerate(ints):
loc = gl.glGetUniformLocation(program, f"u_int{i}")
if loc >= 0:
gl.glUniform1i(loc, v)
gl.glViewport(0, 0, width, height)
gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
# Process each batch
all_batch_outputs = []
for images in image_batches:
# Update input textures with this batch's images
for i, img in enumerate(images):
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
# Flip vertically for GL coordinates, ensure RGBA
h, w, c = img.shape
if c == 3:
img_upload = np.empty((h, w, 4), dtype=np.float32)
img_upload[:, :, :3] = img[::-1, :, :]
img_upload[:, :, 3] = 1.0
else:
img_upload = np.ascontiguousarray(img[::-1, :, :])
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)
# Render
gl.glClearColor(0, 0, 0, 0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
for _ in range(num_outputs, MAX_OUTPUTS):
batch_outputs.append(black_img)
all_batch_outputs.append(batch_outputs)
return all_batch_outputs
finally:
# Unbind before deleting
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
if program is not None:
gl.glDeleteProgram(program)
class GLSLShader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
image_template = io.Autogrow.TemplatePrefix(
io.Image.Input("image"),
prefix="image",
min=1,
max=MAX_IMAGES,
)
float_template = io.Autogrow.TemplatePrefix(
io.Float.Input("float", default=0.0),
prefix="u_float",
min=0,
max=MAX_UNIFORMS,
)
int_template = io.Autogrow.TemplatePrefix(
io.Int.Input("int", default=0),
prefix="u_int",
min=0,
max=MAX_UNIFORMS,
)
return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
category="image/shader",
description=(
f"Apply GLSL fragment shaders to images. "
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
),
inputs=[
io.String.Input(
"fragment_shader",
default=DEFAULT_FRAGMENT_SHADER,
multiline=True,
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
),
io.DynamicCombo.Input(
"size_mode",
options=[
io.DynamicCombo.Option("from_input", []),
io.DynamicCombo.Option(
"custom",
[
io.Int.Input(
"width",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
io.Int.Input(
"height",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
],
),
],
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
),
io.Autogrow.Input("images", template=image_template),
io.Autogrow.Input("floats", template=float_template),
io.Autogrow.Input("ints", template=int_template),
],
outputs=[
io.Image.Output(display_name="IMAGE0"),
io.Image.Output(display_name="IMAGE1"),
io.Image.Output(display_name="IMAGE2"),
io.Image.Output(display_name="IMAGE3"),
],
)
@classmethod
def execute(
cls,
fragment_shader: str,
size_mode: SizeModeInput,
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
if not image_list:
raise ValueError("At least one input image is required")
# Determine output dimensions
if size_mode["size_mode"] == "custom":
out_width = size_mode["width"]
out_height = size_mode["height"]
else:
out_height, out_width = image_list[0].shape[1:3]
batch_size = image_list[0].shape[0]
# Prepare batches
image_batches = []
for batch_idx in range(batch_size):
batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
image_batches.append(batch_images)
all_batch_outputs = _render_shader_batch(
fragment_shader,
out_width,
out_height,
image_batches,
float_list,
int_list,
)
# Collect outputs into tensors
all_outputs = [[] for _ in range(MAX_OUTPUTS)]
for batch_outputs in all_batch_outputs:
for i, out_img in enumerate(batch_outputs):
all_outputs[i].append(torch.from_numpy(out_img))
output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
return io.NodeOutput(
*output_tensors,
ui=cls._build_ui_output(image_list, output_tensors[0]),
)
@classmethod
def _build_ui_output(
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images(
combined_inputs,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
output_images_ui = ui.ImageSaveHelper.save_images(
output_batch,
filename_prefix="GLSLShader_output",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
return {"input_images": input_images_ui, "images": output_images_ui}
class GLSLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [GLSLShader]
async def comfy_entrypoint() -> GLSLExtension:
return GLSLExtension()

View File

@@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
display_name="Image Crop (Deprecated)",
category="image/transform",
is_deprecated=True,
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove
class ImageCropV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, image, crop_region) -> IO.NodeOutput:
x = crop_region.get("x", 0)
y = crop_region.get("y", 0)
width = crop_region.get("width", 512)
height = crop_region.get("height", 512)
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return IO.NodeOutput(img, ui=UI.PreviewImage(img))
class BoundingBox(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PrimitiveBoundingBox",
display_name="Bounding Box",
category="utils/primitive",
inputs=[
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
],
outputs=[IO.BoundingBox.Output()],
)
@classmethod
def execute(cls, x, y, width, height) -> IO.NodeOutput:
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -632,6 +684,8 @@ class ImagesExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
ImageCropV2,
BoundingBox,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,

99
comfy_extras/nodes_nag.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class NAGuidance(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="NAGuidance",
display_name="Normalized Attention Guidance",
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to apply NAG to."),
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
],
outputs=[
io.Model.Output(tooltip="The patched model with NAG enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
m = model.clone()
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
def nag_attention_output_patch(out, extra_options):
cond_or_uncond = extra_options.get("cond_or_uncond", None)
if cond_or_uncond is None:
return out
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
return out
# sigma = extra_options.get("sigmas", None)
# if sigma is not None and len(sigma) > 0:
# sigma = sigma[0].item()
# if sigma > sigma_start or sigma < sigma_end:
# return out
img_slice = extra_options.get("img_slice", None)
if img_slice is not None:
orig_out = out
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
batch_size = out.shape[0]
half_size = batch_size // len(cond_or_uncond)
ind_neg = cond_or_uncond.index(1)
ind_pos = cond_or_uncond.index(0)
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
eps = 1e-6
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
ratio = norm_guided / norm_pos
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
guided_normalized = guided * scale_factor
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
if img_slice is not None:
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
return orig_out
else:
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
return out
m.set_model_attn1_output_patch(nag_attention_output_patch)
m.disable_model_cfg1_optimization()
return io.NodeOutput(m)
class NagExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
NAGuidance,
]
async def comfy_entrypoint() -> NagExtension:
return NagExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.13.0"
__version__ = "0.14.1"

View File

@@ -53,3 +53,14 @@
# checkpoints: models/checkpoints
# gligen: models/gligen
# custom_nodes: path/custom_nodes
# all_model_folders: when set to true alongside base_path, automatically registers
# base_path/{folder_name} for every model folder type known to ComfyUI.
# custom_nodes is excluded. This keeps configs in sync as new model types are added.
# Can be combined with explicit entries to add extra paths for specific types.
#shared_models:
# base_path: /path/to/shared/models
# is_default: true
# all_model_folders: true

View File

@@ -2433,11 +2433,11 @@ async def init_builtin_extra_nodes():
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.13.0"
version = "0.14.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.14
comfyui-workflow-templates==0.8.42
comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.43
comfyui-embedded-docs==0.4.1
torch
torchsde
@@ -30,6 +30,3 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
PyOpenGL-accelerate
glfw

View File

@@ -301,3 +301,147 @@ def test_load_extra_path_config_no_base_path(
actual_diffusion = folder_paths.folder_names_and_paths["diffusion_models"][0]
assert len(actual_diffusion) == 1, "Should have one path for 'diffusion_models'."
assert actual_diffusion[0] == os.path.abspath(expected_unet)
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
@patch("yaml.safe_load")
def test_load_extra_path_config_all_model_folders(
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
):
"""
Test that when a config group has all_model_folders: true, it registers
base_path/{folder_name} for every known model folder type (excluding custom_nodes).
"""
# Pre-populate the registry with a few folder types to simulate ComfyUI's defaults
folder_paths.folder_names_and_paths["checkpoints"] = ([str(tmp_path / "original" / "checkpoints")], {".safetensors"})
folder_paths.folder_names_and_paths["loras"] = ([str(tmp_path / "original" / "loras")], {".safetensors"})
folder_paths.folder_names_and_paths["vae"] = ([str(tmp_path / "original" / "vae")], {".safetensors"})
folder_paths.folder_names_and_paths["custom_nodes"] = ([str(tmp_path / "original" / "custom_nodes")], set())
abs_base = str(tmp_path / "shared_models")
config_data = {
"wildcard_group": {
"base_path": abs_base,
"is_default": True,
"all_model_folders": True,
}
}
mock_yaml_load.return_value = config_data
dummy_yaml_name = "dummy_wildcard.yaml"
def fake_abspath(path):
if path == dummy_yaml_name:
return os.path.join(str(tmp_path), dummy_yaml_name)
return path
def fake_dirname(path):
return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
monkeypatch.setattr(os.path, "abspath", fake_abspath)
monkeypatch.setattr(os.path, "dirname", fake_dirname)
load_extra_path_config(dummy_yaml_name)
# Each pre-existing model folder type should now have the wildcard path prepended (is_default=True)
for folder_name in ["checkpoints", "loras", "vae"]:
paths = folder_paths.folder_names_and_paths[folder_name][0]
expected_wildcard_path = os.path.normpath(os.path.join(abs_base, folder_name))
assert paths[0] == expected_wildcard_path, \
f"Expected wildcard path at index 0 for '{folder_name}', got {paths[0]}"
assert len(paths) == 2, \
f"Expected 2 paths for '{folder_name}' (wildcard + original), got {len(paths)}"
# custom_nodes should NOT be included in wildcard expansion
custom_nodes_paths = folder_paths.folder_names_and_paths["custom_nodes"][0]
assert len(custom_nodes_paths) == 1, \
f"Expected custom_nodes to be untouched by wildcard, got {len(custom_nodes_paths)} paths"
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
@patch("yaml.safe_load")
def test_load_extra_path_config_all_model_folders_with_explicit_entries(
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
):
"""
Test that all_model_folders: true works alongside explicit folder entries.
The wildcard covers all types, and the explicit entry adds an additional path.
"""
folder_paths.folder_names_and_paths["checkpoints"] = ([str(tmp_path / "original" / "checkpoints")], {".safetensors"})
folder_paths.folder_names_and_paths["loras"] = ([str(tmp_path / "original" / "loras")], {".safetensors"})
abs_base = str(tmp_path / "shared_models")
config_data = {
"mixed_group": {
"base_path": abs_base,
"all_model_folders": True,
"checkpoints": "my_checkpoints",
}
}
mock_yaml_load.return_value = config_data
dummy_yaml_name = "dummy_mixed.yaml"
def fake_abspath(path):
if path == dummy_yaml_name:
return os.path.join(str(tmp_path), dummy_yaml_name)
return path
def fake_dirname(path):
return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
monkeypatch.setattr(os.path, "abspath", fake_abspath)
monkeypatch.setattr(os.path, "dirname", fake_dirname)
load_extra_path_config(dummy_yaml_name)
# loras should have the wildcard path appended (no is_default)
lora_paths = folder_paths.folder_names_and_paths["loras"][0]
assert len(lora_paths) == 2
assert lora_paths[1] == os.path.normpath(os.path.join(abs_base, "loras"))
# checkpoints should have both the wildcard path AND the explicit path
checkpoint_paths = folder_paths.folder_names_and_paths["checkpoints"][0]
assert len(checkpoint_paths) == 3
assert checkpoint_paths[1] == os.path.normpath(os.path.join(abs_base, "checkpoints"))
assert checkpoint_paths[2] == os.path.normpath(os.path.join(abs_base, "my_checkpoints"))
@patch("builtins.open", new_callable=mock_open, read_data="dummy yaml content")
@patch("yaml.safe_load")
def test_load_extra_path_config_base_path_only_no_flag(
mock_yaml_load, _mock_file, clear_folder_paths, monkeypatch, tmp_path
):
"""
Test that base_path alone (without all_model_folders) does NOT trigger
wildcard expansion — it just has no effect without explicit folder entries.
"""
folder_paths.folder_names_and_paths["checkpoints"] = ([str(tmp_path / "original" / "checkpoints")], {".safetensors"})
folder_paths.folder_names_and_paths["loras"] = ([str(tmp_path / "original" / "loras")], {".safetensors"})
abs_base = str(tmp_path / "shared_models")
config_data = {
"base_only_group": {
"base_path": abs_base,
}
}
mock_yaml_load.return_value = config_data
dummy_yaml_name = "dummy_base_only.yaml"
def fake_abspath(path):
if path == dummy_yaml_name:
return os.path.join(str(tmp_path), dummy_yaml_name)
return path
def fake_dirname(path):
return str(tmp_path) if path.endswith(dummy_yaml_name) else os.path.dirname(path)
monkeypatch.setattr(os.path, "abspath", fake_abspath)
monkeypatch.setattr(os.path, "dirname", fake_dirname)
load_extra_path_config(dummy_yaml_name)
# Nothing should have changed — no wildcard, no explicit entries
assert len(folder_paths.folder_names_and_paths["checkpoints"][0]) == 1
assert len(folder_paths.folder_names_and_paths["loras"][0]) == 1

View File

@@ -20,6 +20,16 @@ def load_extra_path_config(yaml_path):
is_default = False
if "is_default" in conf:
is_default = conf.pop("is_default")
all_model_folders = False
if "all_model_folders" in conf:
all_model_folders = conf.pop("all_model_folders")
if all_model_folders and base_path:
for folder_name in list(folder_paths.folder_names_and_paths.keys()):
if folder_name == "custom_nodes":
continue
full_path = os.path.normpath(os.path.join(base_path, folder_name))
logging.info("Adding extra search path {} {}".format(folder_name, full_path))
folder_paths.add_model_folder_path(folder_name, full_path, is_default)
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0: