Compare commits

..

1 Commits

Author SHA1 Message Date
Jedrzej Kosinski
a145651cc0 Track custom node startup errors and expose via API endpoint
Store import and prestartup errors in NODE_STARTUP_ERRORS dict (nodes.py,
main.py) and add GET /custom_node_startup_errors endpoint (server.py) so
the frontend/Manager can distinguish failed imports from missing nodes.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019d2346-6e6f-75e0-a97f-cdb6e26859f7
Co-authored-by: Amp <amp@ampcode.com>
2026-03-24 23:41:01 -07:00
28 changed files with 423 additions and 497422 deletions

View File

@@ -93,13 +93,12 @@ def compute_relative_filename(file_path: str) -> str | None:
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "temp", "models"], str]:
) -> tuple[Literal["input", "output", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'temp': under folder_paths.get_temp_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
@@ -130,12 +129,7 @@ def get_asset_category_and_relative_path(
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)
# 4) models (check deepest matching base to avoid ambiguity)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
@@ -152,7 +146,7 @@ def get_asset_category_and_relative_path(
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
f"Path is not within input, output, or configured model bases: {file_path}"
)

View File

@@ -1,90 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 rgb2hsl(vec3 c) {
float maxC = max(c.r, max(c.g, c.b));
float minC = min(c.r, min(c.g, c.b));
float l = (maxC + minC) * 0.5;
if (maxC == minC) return vec3(0.0, 0.0, l);
float d = maxC - minC;
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
float h;
if (maxC == c.r) {
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / d + 2.0;
} else {
h = (c.r - c.g) / d + 4.0;
}
h /= 6.0;
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
if (t < 0.0) t += 1.0;
if (t > 1.0) t -= 1.0;
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
if (t < 1.0 / 2.0) return q;
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
float h = hsl.x, s = hsl.y, l = hsl.z;
if (s == 0.0) return vec3(l);
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
float p = 2.0 * l - q;
return vec3(
hue2rgb(p, q, h + 1.0 / 3.0),
hue2rgb(p, q, h),
hue2rgb(p, q, h - 1.0 / 3.0)
);
}
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float lightness = (maxC + minC) * 0.5;
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
const float a = 0.25;
const float b = 0.333;
const float scale = 0.7;
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
color += sw * shadows + mw * midtones + hw * highlights;
if (u_bool0) {
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
hsl.z = lightness;
color = hsl2rgb(hsl);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@@ -1,46 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
// index = value * (n_samples - 1)
// f = fract(index)
// result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
value = clamp(value, 0.0, 1.0);
float pos = value * 255.0;
int lo = int(floor(pos));
int hi = min(lo + 1, 255);
float f = pos - float(lo);
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
return a + f * (b - a);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) )
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
fragColor0 = vec4(color.rgb, color.a);
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -928,7 +928,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
@@ -1035,9 +1034,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:

View File

@@ -61,7 +61,6 @@ import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.model_patcher
import comfy.lora
@@ -426,13 +425,13 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -1229,11 +1228,6 @@ class TEModel(Enum):
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
QWEN35_08B = 23
QWEN35_2B = 24
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27
def detect_te_model(sd):
@@ -1273,17 +1267,6 @@ def detect_te_model(sd):
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
weight = sd['model.language_model.layers.0.input_layernorm.weight']
if weight.shape[0] == 1024:
return TEModel.QWEN35_08B
if weight.shape[0] == 2560:
return TEModel.QWEN35_4B
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1316,12 +1299,11 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
weight_name = "model.layers.0.self_attn.k_proj.weight"
for sd in clip_data:
for weight_name in weight_names:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
@@ -1449,11 +1431,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer

View File

@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def parse_parentheses(string):
result = []
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

View File

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
@@ -655,17 +655,6 @@ class Llama2_(nn.Module):
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
return past_key_values[0][2]
def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
@@ -678,12 +667,17 @@ class Llama2_(nn.Module):
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = self.get_past_len(past_key_values)
past_len = past_key_values[0][2]
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
mask = None
if attention_mask is not None:
@@ -818,16 +812,9 @@ class BaseGenerate:
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
@@ -842,8 +829,11 @@ class BaseGenerate:
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
@@ -854,7 +844,7 @@ class BaseGenerate:
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
@@ -866,7 +856,7 @@ class BaseGenerate:
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
@@ -877,11 +867,6 @@ class BaseGenerate:
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty
if temperature != 1.0:
logits = logits / temperature
@@ -912,9 +897,6 @@ class BaseGenerate:
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)
module = self.model.embed_tokens
offload_stream = None

View File

@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:

View File

@@ -1,833 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
def _qwen35_layer_types(n):
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
@dataclass
class Qwen35Config:
vocab_size: int = 248320
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 24
# Full attention params
num_attention_heads: int = 8
num_key_value_heads: int = 2
head_dim: int = 256
partial_rotary_factor: float = 0.25
# Linear attention (DeltaNet) params
linear_num_key_heads: int = 16
linear_num_value_heads: int = 16
linear_key_head_dim: int = 128
linear_value_head_dim: int = 128
conv_kernel_size: int = 4
# Shared params
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 10000000.0
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
rms_norm_add: bool = True
mlp_activation: str = "silu"
qkv_bias: bool = False
final_norm: bool = True
lm_head: bool = False
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
transformer_type: str = "qwen35_2b"
rope_dims: list = None
rope_scale: float = None
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
QWEN35_MODELS = {
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
}
def _make_config(model_type, config_dict={}):
overrides = QWEN35_MODELS.get(model_type, {}).copy()
overrides.pop("vision", None)
if "num_hidden_layers" in overrides:
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
overrides.update(config_dict)
return Qwen35Config(**overrides)
class RMSNormGated(RMSNorm):
def forward(self, x, gate):
return super().forward(x) * F.silu(gate.to(x.dtype))
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
initial_dtype = query.dtype
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size))
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
total_sequence_length = sequence_length + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
for i in range(0, total_sequence_length // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
# weight: [channels, kernel_size]
state_len = conv_state.shape[-1]
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
conv_state.copy_(combined[:, :, -state_len:])
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
if bias is not None:
out = out + bias.unsqueeze(0).unsqueeze(-1)
return F.silu(out).to(x.dtype)
# GatedDeltaNet - Linear Attention Layer
class GatedDeltaNet(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden = config.hidden_size
self.num_key_heads = config.linear_num_key_heads
self.num_value_heads = config.linear_num_value_heads
self.key_head_dim = config.linear_key_head_dim
self.value_head_dim = config.linear_value_head_dim
self.conv_kernel_size = config.conv_kernel_size
key_dim = self.num_key_heads * self.key_head_dim
value_dim = self.num_value_heads * self.value_head_dim
self.key_dim = key_dim
self.value_dim = value_dim
conv_dim = key_dim * 2 + value_dim
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(self, x, past_key_value=None, **kwargs):
batch_size, seq_len, _ = x.shape
use_recurrent = (
past_key_value is not None
and past_key_value[2] > 0
and seq_len == 1
)
# Projections (shared)
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
z = self.in_proj_z(x)
b = self.in_proj_b(x)
a = self.in_proj_a(x)
# Conv1d
if use_recurrent:
recurrent_state, conv_state, step_index = past_key_value
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
else:
if past_key_value is not None:
recurrent_state, conv_state, step_index = past_key_value
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# Split QKV and compute beta/g
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
# Delta rule
if use_recurrent:
# single-token path: work in [B, heads, dim] without seq dim
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=1)
key = key.repeat_interleave(rep, dim=1)
scale = self.key_head_dim ** -0.5
q = F.normalize(query.float(), dim=-1) * scale
k = F.normalize(key.float(), dim=-1)
v = value.float()
beta_t = beta.reshape(batch_size, -1)
g_t = g.reshape(batch_size, -1).exp()
# In-place state update: [B, heads, k_dim, v_dim]
recurrent_state.mul_(g_t[:, :, None, None])
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
delta = (v - kv_mem) * beta_t[:, :, None]
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
present_key_value = (recurrent_state, conv_state, step_index + 1)
else:
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=2)
key = key.repeat_interleave(rep, dim=2)
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
query, key, value, g=g, beta=beta,
initial_state=None,
output_final_state=past_key_value is not None,
)
present_key_value = None
if past_key_value is not None:
if last_recurrent_state is not None:
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
# Gated norm + output projection (shared)
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
return output, present_key_value
# GatedAttention - Full Attention with output gating
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
"""Compute RoPE frequencies for partial rotary embeddings."""
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if mrope_section is not None and position_ids.shape[0] == 3:
mrope_section_2 = [s * 2 for s in mrope_section]
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
"""Apply RoPE to only the first rotary_dim dimensions."""
xq_rot = xq[..., :rotary_dim]
xq_pass = xq[..., rotary_dim:]
xk_rot = xk[..., :rotary_dim]
xk_pass = xk[..., rotary_dim:]
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
xq = torch.cat([xq_rot, xq_pass], dim=-1)
xk = torch.cat([xk_rot, xk_pass], dim=-1)
return xq, xk
class GatedAttention(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.inner_size = self.num_heads * self.head_dim
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
# q_proj outputs 2x: query + gate
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
# QK norms with (1+weight) scaling
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
batch_size, seq_length, _ = x.shape
# Project Q (with gate), K, V
qg = self.q_proj(x)
# Split into query and gate: each is [B, seq, inner_size]
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
xk = self.k_proj(x)
xv = self.v_proj(x)
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
# Apply partial RoPE
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
# KV cache
present_key_value = None
if past_key_value is not None:
past_key, past_value, index = past_key_value
num_tokens = xk.shape[2]
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + num_tokens] = xk
past_value[:, :, index:index + num_tokens] = xv
xk = past_key[:, :, :index + num_tokens]
xv = past_value[:, :, :index + num_tokens]
present_key_value = (past_key, past_value, index + num_tokens)
else:
if index > 0:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value
# Hybrid Transformer Block
class Qwen35TransformerBlock(nn.Module):
def __init__(self, config, index, device=None, dtype=None, ops=None):
super().__init__()
self.layer_type = config.layer_types[index]
if self.layer_type == "linear_attention":
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
else:
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
if self.layer_type == "linear_attention":
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
else:
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
x = x + h
x = x + self.mlp(self.post_attention_layernorm(x))
return x, present_key_value
# Qwen35 Transformer Backbone
class Qwen35Transformer(Llama2_):
def __init__(self, config, device=None, dtype=None, ops=None):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
for i, layer in enumerate(self.layers):
if layer.layer_type == "full_attention":
if len(past_key_values) > i:
return past_key_values[i][2]
break
return 0
def compute_freqs_cis(self, position_ids, device):
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
return precompute_partial_rope(
self.config.head_dim, rotary_dim, position_ids,
self.config.rope_theta, device=device,
mrope_section=self.config.mrope_section,
)
# Vision Encoder
class Qwen35VisionPatchEmbed(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.temporal_patch_size = config["temporal_patch_size"]
self.in_channels = config["in_channels"]
self.embed_dim = config["hidden_size"]
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, hidden_state):
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
class Qwen35VisionRotaryEmbedding(nn.Module):
def __init__(self, dim, theta=10000.0):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen):
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Qwen35VisionAttention(nn.Module):
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.dim = hidden_size
self.num_heads = num_heads
self.head_dim = self.dim // self.num_heads
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
seq_length = x.shape[0]
query_states, key_states, value_states = (
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
# Process per-sequence attention
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_splits = torch.split(query_states, lengths, dim=0)
k_splits = torch.split(key_states, lengths, dim=0)
v_splits = torch.split(value_states, lengths, dim=0)
attn_outputs = []
for q, k, v in zip(q_splits, k_splits, v_splits):
q = q.transpose(0, 1).unsqueeze(0)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
return self.proj(attn_output)
class Qwen35VisionBlock(nn.Module):
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
return x + self.mlp(self.norm2(x))
class Qwen35VisionPatchMerger(nn.Module):
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
super().__init__()
merge_dim = hidden_size * (spatial_merge_size ** 2)
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
self.merge_dim = merge_dim
def forward(self, x):
x = self.norm(x).view(-1, self.merge_dim)
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
class Qwen35VisionModel(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.spatial_merge_size = config["spatial_merge_size"]
self.patch_size = config["patch_size"]
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_heads"]
self.num_position_embeddings = config["num_position_embeddings"]
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
self.blocks = nn.ModuleList([
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
for _ in range(config["depth"])
])
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
def rot_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
grid_thw_list = grid_thw.tolist()
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
freq_table = self.rotary_pos_emb(max_hw)
device = freq_table.device
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw_list:
num_frames, height, width = int(num_frames), int(height), int(width)
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra_row = torch.arange(merge_size, device=device)
intra_col = torch.arange(merge_size, device=device)
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
num_tokens = coords.shape[0]
pos_ids[offset:offset + num_tokens] = coords
offset += num_tokens
embeddings = freq_table[pos_ids]
embeddings = embeddings.flatten(1)
return embeddings
def fast_pos_embed_interpolate(self, grid_thw):
grid_thw_list = grid_thw.tolist()
grid_ts = [int(row[0]) for row in grid_thw_list]
grid_hs = [int(row[1]) for row in grid_thw_list]
grid_ws = [int(row[2]) for row in grid_thw_list]
device = self.pos_embed.weight.device
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
for t, h, w in grid_thw_list:
h, w = int(h), int(w)
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for j in range(4):
idx_list[j].extend(indices[j].tolist())
weight_list[j].extend(weights[j].tolist())
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
patch_pos_embeds_permute = []
merge_size = self.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
def forward(self, x, grid_thw):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().unsqueeze(-2)
sin = emb.sin().unsqueeze(-2)
sin_half = sin.shape[-1] // 2
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
merged = self.merger(x)
return merged
# Model Wrapper
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
model_type = "qwen35_2b"
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = _make_config(self.model_type, config_dict)
self.num_layers = config.num_hidden_layers
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for i in range(model_config.num_hidden_layers):
if model_config.layer_types[i] == "linear_attention":
recurrent_state = torch.zeros(
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
device=device, dtype=torch.float32
)
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
conv_state = torch.zeros(
[batch, conv_dim, model_config.conv_kernel_size - 1],
device=device, dtype=execution_dtype
)
past_key_values.append((recurrent_state, conv_state, 0))
else:
past_key_values.append((
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
0
))
return past_key_values
# Tokenizer and Text Encoder Wrappers
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
from transformers import Qwen2Tokenizer
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 248056: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen35ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
class Qwen35_(Qwen35):
pass
Qwen35_.model_type = model_type
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen35TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
def tokenizer(model_type="qwen35_2b"):
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
return Qwen35ImageTokenizer_
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
class Qwen35TEModel_(Qwen35TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
return Qwen35TEModel_

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -145,20 +145,7 @@ class ReveImageCreateNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["upscale", "upscale.upscale_factor"],
),
expr="""
(
$factor := $lookup(widgets, "upscale.upscale_factor");
$fmt := {"approximate": true, "note": "(base)"};
widgets.upscale = "enabled" ? (
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
: {"type": "usd", "usd": 0.0457, "format": $fmt}
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
)
""",
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@@ -238,21 +225,13 @@ class ReveImageEditNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "upscale", "upscale.upscale_factor"],
widgets=["model"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
@@ -348,21 +327,13 @@ class ReveImageRemixNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "upscale", "upscale.upscale_factor"],
widgets=["model"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),

View File

@@ -38,7 +38,6 @@ from comfy_api_nodes.util import (
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}

View File

@@ -1,67 +1,85 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
import OpenGL
OpenGL.USE_ACCELERATE = False
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@@ -69,9 +87,7 @@ class SizeModeInput(TypedDict):
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
MAX_BOOLS = 10 # u_bool0-9
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
@@ -84,7 +100,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 300 es
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@@ -108,21 +124,14 @@ void main() {
"""
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _detect_output_count(source: str) -> int:
@@ -148,8 +157,163 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
_instance = None
_initialized = False
@@ -161,111 +325,131 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._display = None
self._surface = None
self._context = None
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if not self._display:
raise RuntimeError("eglGetDisplay() returned no display")
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
if not EGL.eglInitialize(self._display, ctypes.byref(major), ctypes.byref(minor)):
err = EGL.eglGetError()
self._display = None
raise RuntimeError(f"eglInitialize() failed (EGL error: 0x{err:04X})")
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
elapsed = (time.perf_counter() - start) * 1000
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
def make_current(self):
EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context)
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@@ -273,10 +457,8 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@@ -300,10 +482,8 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@@ -317,8 +497,6 @@ def _render_shader_batch(
image_batches: list[list[np.ndarray]],
floats: list[float],
ints: list[int],
bools: list[bool] | None = None,
curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
@@ -333,8 +511,6 @@ def _render_shader_batch(
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
@@ -348,23 +524,20 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
# Detect multi-pass rendering
num_passes = _detect_pass_count(fragment_code)
if bools is None:
bools = []
if curves is None:
curves = []
# Track resources for cleanup
program = None
fbo = None
output_textures = []
input_textures = []
curve_textures = []
ping_pong_textures = []
ping_pong_fbos = []
@@ -373,9 +546,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_code)
program = _create_program(VERTEX_SHADER, fragment_source)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_code}")
logger.error(f"Fragment shader:\n{fragment_source}")
raise
gl.glUseProgram(program)
@@ -451,28 +624,6 @@ def _render_shader_batch(
if loc >= 0:
gl.glUniform1i(loc, v)
for i, v in enumerate(bools):
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
if loc >= 0:
gl.glUniform1i(loc, 1 if v else 0)
# Create 1D LUT textures for curves (bound after image texture units)
for i, lut in enumerate(curves):
tex = gl.glGenTextures(1)
curve_textures.append(tex)
unit = MAX_IMAGES + i
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
if loc >= 0:
gl.glUniform1i(loc, unit)
# Get u_pass uniform location for multi-pass
pass_loc = gl.glGetUniformLocation(program, "u_pass")
@@ -538,13 +689,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
# (glGetTexImage is synchronous, implicitly waits for rendering)
batch_outputs = []
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@@ -565,18 +716,16 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if program is not None:
gl.glDeleteProgram(program)
@@ -605,20 +754,6 @@ class GLSLShader(io.ComfyNode):
max=MAX_UNIFORMS,
)
bool_template = io.Autogrow.TemplatePrefix(
io.Boolean.Input("bool", default=False),
prefix="u_bool",
min=0,
max=MAX_BOOLS,
)
curve_template = io.Autogrow.TemplatePrefix(
io.Curve.Input("curve"),
prefix="u_curve",
min=0,
max=MAX_CURVES,
)
return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
@@ -627,7 +762,6 @@ class GLSLShader(io.ComfyNode):
"Apply GLSL ES fragment shaders to images. "
"u_resolution (vec2) is always available."
),
is_experimental=True,
inputs=[
io.String.Input(
"fragment_shader",
@@ -662,8 +796,6 @@ class GLSLShader(io.ComfyNode):
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
],
outputs=[
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
@@ -681,19 +813,13 @@ class GLSLShader(io.ComfyNode):
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
bools: io.Autogrow.Type = None,
curves: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
if not image_list:
raise ValueError("At least one input image is required")
@@ -720,8 +846,6 @@ class GLSLShader(io.ComfyNode):
image_batches,
float_list,
int_list,
bool_list,
curve_luts,
)
# Collect outputs into tensors

View File

@@ -44,13 +44,8 @@ class NumberConvertNode(io.ComfyNode):
def execute(cls, value) -> io.NodeOutput:
if isinstance(value, bool):
float_val = 1.0 if value else 0.0
int_val = 1 if value else 0
elif isinstance(value, int):
elif isinstance(value, (int, float)):
float_val = float(value)
int_val = value
elif isinstance(value, float):
float_val = value
int_val = int(value)
elif isinstance(value, str):
text = value.strip()
if not text:
@@ -61,14 +56,6 @@ class NumberConvertNode(io.ComfyNode):
raise ValueError(
f"Cannot convert string to number: {value!r}"
) from None
if not math.isfinite(float_val):
raise ValueError(
f"Cannot convert non-finite value to number: {float_val}"
)
try:
int_val = int(text)
except ValueError:
int_val = int(float_val)
else:
raise TypeError(
f"Unsupported input type: {type(value).__name__}"
@@ -79,7 +66,7 @@ class NumberConvertNode(io.ComfyNode):
f"Cannot convert non-finite value to number: {float_val}"
)
return io.NodeOutput(float_val, int_val)
return io.NodeOutput(float_val, int(float_val))
class NumberConvertExtension(ComfyExtension):

View File

@@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
def g(cls, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return (g / g.sum()).to(dtype)
return g / g.sum()
class Blur(io.ComfyNode):
@classmethod
@@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
@@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
image = image.to(comfy.model_management.get_torch_device())
kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
kernel = kernel.to(dtype=image.dtype)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

View File

@@ -15,7 +15,6 @@ class TextGenerate(io.ComfyNode):
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
]
),
io.DynamicCombo.Option(
@@ -26,7 +25,7 @@ class TextGenerate(io.ComfyNode):
return io.Schema(
node_id="TextGenerate",
category="textgen",
category="textgen/",
search_aliases=["LLM", "gemma"],
inputs=[
io.Clip.Input("clip"),
@@ -34,7 +33,6 @@ class TextGenerate(io.ComfyNode):
io.Image.Input("image", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
],
outputs=[
io.String.Output(display_name="generated_text"),
@@ -42,9 +40,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -54,7 +52,6 @@ class TextGenerate(io.ComfyNode):
min_p = sampling_mode.get("min_p", 0.0)
seed = sampling_mode.get("seed", None)
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
generated_ids = clip.generate(
tokens,
@@ -65,7 +62,6 @@ class TextGenerate(io.ComfyNode):
top_p=top_p,
min_p=min_p,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
seed=seed
)
@@ -160,12 +156,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
class TextgenExtension(ComfyExtension):

View File

@@ -139,7 +139,16 @@ def execute_prestartup_script():
spec.loader.exec_module(module)
return True
except Exception as e:
import traceback
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
from nodes import NODE_STARTUP_ERRORS, get_module_name
node_module_name = get_module_name(os.path.dirname(script_path))
NODE_STARTUP_ERRORS[node_module_name] = {
"module_path": os.path.dirname(script_path),
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "prestartup",
}
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")

View File

@@ -1 +1 @@
comfyui_manager==4.1
comfyui_manager==4.1b8

View File

@@ -2181,6 +2181,9 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by module name.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
def get_module_name(module_path: str) -> str:
"""
@@ -2298,6 +2301,13 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
except Exception as e:
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
module_name = get_module_name(module_path)
NODE_STARTUP_ERRORS[module_name] = {
"module_path": module_path,
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "import",
}
return False
async def init_external_custom_nodes():

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.38
comfyui-workflow-templates==0.9.36
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL>=3.1.8
comfy-angle
PyOpenGL
glfw

View File

@@ -753,6 +753,10 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/custom_node_startup_errors")
async def get_custom_node_startup_errors(request):
return web.json_response(nodes.NODE_STARTUP_ERRORS)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.

View File

@@ -1,81 +0,0 @@
"""Tests for path_utils asset category resolution."""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path
@pytest.fixture
def fake_dirs():
"""Create temporary input, output, and temp directories."""
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
models_dir = root_path / "models" / "checkpoints"
for d in (input_dir, output_dir, temp_dir, models_dir):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(models_dir)])],
):
yield {
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"models": models_dir,
}
class TestGetAssetCategoryAndRelativePath:
def test_input_file(self, fake_dirs):
f = fake_dirs["input"] / "photo.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "input"
assert rel == "photo.png"
def test_output_file(self, fake_dirs):
f = fake_dirs["output"] / "result.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "output"
assert rel == "result.png"
def test_temp_file(self, fake_dirs):
"""Regression: temp files must be categorised, not raise ValueError."""
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "temp"
assert rel == "GLSLShader_output_00004_.png"
def test_temp_file_in_subfolder(self, fake_dirs):
sub = fake_dirs["temp"] / "sub"
sub.mkdir()
f = sub / "ComfyUI_temp_tczip_00004_.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "temp"
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
def test_model_file(self, fake_dirs):
f = fake_dirs["models"] / "model.safetensors"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "models"
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")

View File

@@ -90,63 +90,6 @@ class TestNumberConvertExecute:
assert result[0] == 1000.0
assert result[1] == 1000
# --- Large number precision (string input) ---
def test_string_large_int_above_2_53(self):
"""Text-to-int must not lose precision for integers beyond 2^53."""
big = 2**53 + 1 # 9007199254740993
result = self._exec(str(big))
assert result[1] == big
def test_string_large_negative_int_above_2_53(self):
big = -(2**53 + 1)
result = self._exec(str(big))
assert result[1] == big
def test_string_very_large_int(self):
big = 2**63 + 42
result = self._exec(str(big))
assert result[1] == big
def test_string_large_int_float_output_is_float(self):
"""FLOAT output is still a float (may lose precision, but must be float type)."""
result = self._exec(str(2**53 + 1))
assert isinstance(result[0], float)
# --- Large number precision (int input) ---
def test_int_large_above_2_53(self):
"""Native int input must preserve its value in the INT output."""
big = 2**53 + 1
result = self._exec(big)
assert result[1] == big
def test_int_large_negative_above_2_53(self):
big = -(2**53 + 1)
result = self._exec(big)
assert result[1] == big
def test_int_very_large(self):
big = 2**100
result = self._exec(big)
assert result[1] == big
# --- String decimal / scientific notation fallback ---
def test_string_decimal_still_truncates(self):
"""Strings with decimal points fall back to int(float(...)) truncation."""
result = self._exec("3.7")
assert result[1] == 3
def test_string_negative_decimal_truncates(self):
result = self._exec("-2.9")
assert result[1] == -2
def test_string_scientific_large(self):
result = self._exec("1e18")
assert result[0] == 1e18
assert result[1] == 10**18
# --- STRING error paths ---
def test_empty_string_raises(self):

View File

@@ -24,7 +24,6 @@ def init_mime_types():
# Web types (used by server.py for static file serving)
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
mimetypes.add_type('image/webp', '.webp')
mimetypes.add_type('image/svg+xml', '.svg')
# Model and data file types (used by asset scanning / metadata extraction)
mimetypes.add_type("application/safetensors", ".safetensors")