mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 08:57:29 +00:00
Compare commits
1 Commits
pysssss/an
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a145651cc0 |
@@ -93,13 +93,12 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'temp': under folder_paths.get_temp_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
@@ -130,12 +129,7 @@ def get_asset_category_and_relative_path(
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) temp
|
||||
temp_base = os.path.abspath(folder_paths.get_temp_directory())
|
||||
if _check_is_within(fp_abs, temp_base):
|
||||
return "temp", _compute_relative(fp_abs, temp_base)
|
||||
|
||||
# 4) models (check deepest matching base to avoid ambiguity)
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
@@ -152,7 +146,7 @@ def get_asset_category_and_relative_path(
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
uniform float u_float5;
|
||||
uniform float u_float6;
|
||||
uniform float u_float7;
|
||||
uniform float u_float8;
|
||||
uniform bool u_bool0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(c.r, max(c.g, c.b));
|
||||
float minC = min(c.r, min(c.g, c.b));
|
||||
float l = (maxC + minC) * 0.5;
|
||||
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||
float d = maxC - minC;
|
||||
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||
float h;
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / d + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / d + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
if (t < 0.0) t += 1.0;
|
||||
if (t > 1.0) t -= 1.0;
|
||||
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 1.0 / 2.0) return q;
|
||||
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||
if (s == 0.0) return vec3(l);
|
||||
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||
float p = 2.0 * l - q;
|
||||
return vec3(
|
||||
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||
hue2rgb(p, q, h),
|
||||
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||
);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float lightness = (maxC + minC) * 0.5;
|
||||
|
||||
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||
const float a = 0.25;
|
||||
const float b = 0.333;
|
||||
const float scale = 0.7;
|
||||
|
||||
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||
|
||||
color += sw * shadows + mw * midtones + hw * highlights;
|
||||
|
||||
if (u_bool0) {
|
||||
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||
hsl.z = lightness;
|
||||
color = hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||
uniform sampler2D u_curve1; // Red channel curve
|
||||
uniform sampler2D u_curve2; // Green channel curve
|
||||
uniform sampler2D u_curve3; // Blue channel curve
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||
// index = value * (n_samples - 1)
|
||||
// f = fract(index)
|
||||
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||
//
|
||||
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||
float applyCurve(sampler2D curve, float value) {
|
||||
value = clamp(value, 0.0, 1.0);
|
||||
|
||||
float pos = value * 255.0;
|
||||
int lo = int(floor(pos));
|
||||
int hi = min(lo + 1, 255);
|
||||
float f = pos - float(lo);
|
||||
|
||||
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||
|
||||
return a + f * (b - a);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// GIMP order: per-channel curves first, then RGB master curve.
|
||||
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||
// dest = colors_curve( channel_curve( src ) )
|
||||
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
|
||||
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
|
||||
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
|
||||
|
||||
fragColor0 = vec4(color.rgb, color.a);
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -928,7 +928,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
@@ -1035,9 +1034,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
|
||||
33
comfy/sd.py
33
comfy/sd.py
@@ -61,7 +61,6 @@ import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -426,13 +425,13 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@@ -1229,11 +1228,6 @@ class TEModel(Enum):
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
QWEN35_08B = 23
|
||||
QWEN35_2B = 24
|
||||
QWEN35_4B = 25
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1273,17 +1267,6 @@ def detect_te_model(sd):
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
|
||||
weight = sd['model.language_model.layers.0.input_layernorm.weight']
|
||||
if weight.shape[0] == 1024:
|
||||
return TEModel.QWEN35_08B
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN35_4B
|
||||
if weight.shape[0] == 4096:
|
||||
return TEModel.QWEN35_9B
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@@ -1316,12 +1299,11 @@ def t5xxl_detect(clip_data):
|
||||
return {}
|
||||
|
||||
def llama_detect(clip_data):
|
||||
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
|
||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||
|
||||
for sd in clip_data:
|
||||
for weight_name in weight_names:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
@@ -1449,11 +1431,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
if isinstance(tokens, dict):
|
||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||
else:
|
||||
tokens_only = tokens
|
||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = True
|
||||
lm_head: bool = False
|
||||
stop_tokens = [151643, 151645]
|
||||
|
||||
@dataclass
|
||||
@@ -655,17 +655,6 @@ class Llama2_(nn.Module):
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
return past_key_values[0][2]
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
return precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
@@ -678,12 +667,17 @@ class Llama2_(nn.Module):
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = self.get_past_len(past_key_values)
|
||||
past_len = past_key_values[0][2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=x.device)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
@@ -818,16 +812,9 @@ class BaseGenerate:
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||
device = embeds.device
|
||||
model_config = self.model.config
|
||||
|
||||
if stop_tokens is None:
|
||||
stop_tokens = self.model.config.stop_tokens
|
||||
@@ -842,8 +829,11 @@ class BaseGenerate:
|
||||
if embeds.ndim == 2:
|
||||
embeds = embeds.unsqueeze(0)
|
||||
|
||||
past_key_values = [] #kv_cache init
|
||||
max_cache_len = embeds.shape[1] + max_length
|
||||
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||
|
||||
@@ -854,7 +844,7 @@ class BaseGenerate:
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
@@ -866,7 +856,7 @@ class BaseGenerate:
|
||||
|
||||
return generated_token_ids
|
||||
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||
|
||||
if not do_sample or temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
@@ -877,11 +867,6 @@ class BaseGenerate:
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] -= presence_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
@@ -912,9 +897,6 @@ class BaseGenerate:
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
if self.model.config.lm_head:
|
||||
return self.model.lm_head(input)
|
||||
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
|
||||
@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
self.dtypes.add(dtype)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
|
||||
@@ -1,833 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import math
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.qwen_vl
|
||||
|
||||
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||
|
||||
|
||||
def _qwen35_layer_types(n):
|
||||
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||
|
||||
@dataclass
|
||||
class Qwen35Config:
|
||||
vocab_size: int = 248320
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 24
|
||||
# Full attention params
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 2
|
||||
head_dim: int = 256
|
||||
partial_rotary_factor: float = 0.25
|
||||
# Linear attention (DeltaNet) params
|
||||
linear_num_key_heads: int = 16
|
||||
linear_num_value_heads: int = 16
|
||||
linear_key_head_dim: int = 128
|
||||
linear_value_head_dim: int = 128
|
||||
conv_kernel_size: int = 4
|
||||
# Shared params
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000000.0
|
||||
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
|
||||
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
|
||||
rms_norm_add: bool = True
|
||||
mlp_activation: str = "silu"
|
||||
qkv_bias: bool = False
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
|
||||
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
|
||||
transformer_type: str = "qwen35_2b"
|
||||
rope_dims: list = None
|
||||
rope_scale: float = None
|
||||
|
||||
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
|
||||
|
||||
QWEN35_MODELS = {
|
||||
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
|
||||
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
|
||||
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
|
||||
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
}
|
||||
|
||||
|
||||
def _make_config(model_type, config_dict={}):
|
||||
overrides = QWEN35_MODELS.get(model_type, {}).copy()
|
||||
overrides.pop("vision", None)
|
||||
if "num_hidden_layers" in overrides:
|
||||
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
|
||||
overrides.update(config_dict)
|
||||
return Qwen35Config(**overrides)
|
||||
|
||||
|
||||
class RMSNormGated(RMSNorm):
|
||||
def forward(self, x, gate):
|
||||
return super().forward(x) * F.silu(gate.to(x.dtype))
|
||||
|
||||
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
|
||||
initial_dtype = query.dtype
|
||||
query = F.normalize(query, dim=-1)
|
||||
key = F.normalize(key, dim=-1)
|
||||
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size))
|
||||
key = F.pad(key, (0, 0, 0, pad_size))
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
total_sequence_length = sequence_length + pad_size
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
last_recurrent_state = (
|
||||
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||
|
||||
for i in range(0, total_sequence_length // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
|
||||
|
||||
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
|
||||
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
|
||||
# weight: [channels, kernel_size]
|
||||
state_len = conv_state.shape[-1]
|
||||
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
|
||||
conv_state.copy_(combined[:, :, -state_len:])
|
||||
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
|
||||
if bias is not None:
|
||||
out = out + bias.unsqueeze(0).unsqueeze(-1)
|
||||
return F.silu(out).to(x.dtype)
|
||||
|
||||
|
||||
# GatedDeltaNet - Linear Attention Layer
|
||||
|
||||
class GatedDeltaNet(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
hidden = config.hidden_size
|
||||
self.num_key_heads = config.linear_num_key_heads
|
||||
self.num_value_heads = config.linear_num_value_heads
|
||||
self.key_head_dim = config.linear_key_head_dim
|
||||
self.value_head_dim = config.linear_value_head_dim
|
||||
self.conv_kernel_size = config.conv_kernel_size
|
||||
|
||||
key_dim = self.num_key_heads * self.key_head_dim
|
||||
value_dim = self.num_value_heads * self.value_head_dim
|
||||
self.key_dim = key_dim
|
||||
self.value_dim = value_dim
|
||||
conv_dim = key_dim * 2 + value_dim
|
||||
|
||||
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
|
||||
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
|
||||
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
|
||||
|
||||
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, past_key_value=None, **kwargs):
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
use_recurrent = (
|
||||
past_key_value is not None
|
||||
and past_key_value[2] > 0
|
||||
and seq_len == 1
|
||||
)
|
||||
|
||||
# Projections (shared)
|
||||
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
|
||||
z = self.in_proj_z(x)
|
||||
b = self.in_proj_b(x)
|
||||
a = self.in_proj_a(x)
|
||||
|
||||
# Conv1d
|
||||
if use_recurrent:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
|
||||
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
|
||||
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
|
||||
else:
|
||||
if past_key_value is not None:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
|
||||
# Split QKV and compute beta/g
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
||||
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
||||
beta = b.sigmoid()
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
|
||||
|
||||
# Delta rule
|
||||
if use_recurrent:
|
||||
# single-token path: work in [B, heads, dim] without seq dim
|
||||
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=1)
|
||||
key = key.repeat_interleave(rep, dim=1)
|
||||
|
||||
scale = self.key_head_dim ** -0.5
|
||||
q = F.normalize(query.float(), dim=-1) * scale
|
||||
k = F.normalize(key.float(), dim=-1)
|
||||
v = value.float()
|
||||
beta_t = beta.reshape(batch_size, -1)
|
||||
g_t = g.reshape(batch_size, -1).exp()
|
||||
|
||||
# In-place state update: [B, heads, k_dim, v_dim]
|
||||
recurrent_state.mul_(g_t[:, :, None, None])
|
||||
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
|
||||
delta = (v - kv_mem) * beta_t[:, :, None]
|
||||
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
|
||||
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
|
||||
|
||||
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
|
||||
present_key_value = (recurrent_state, conv_state, step_index + 1)
|
||||
else:
|
||||
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=2)
|
||||
key = key.repeat_interleave(rep, dim=2)
|
||||
|
||||
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
||||
query, key, value, g=g, beta=beta,
|
||||
initial_state=None,
|
||||
output_final_state=past_key_value is not None,
|
||||
)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
if last_recurrent_state is not None:
|
||||
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
|
||||
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
|
||||
|
||||
# Gated norm + output projection (shared)
|
||||
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
|
||||
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
|
||||
return output, present_key_value
|
||||
|
||||
|
||||
# GatedAttention - Full Attention with output gating
|
||||
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
|
||||
"""Compute RoPE frequencies for partial rotary embeddings."""
|
||||
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
|
||||
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
|
||||
|
||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
if mrope_section is not None and position_ids.shape[0] == 3:
|
||||
mrope_section_2 = [s * 2 for s in mrope_section]
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
|
||||
|
||||
|
||||
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
|
||||
"""Apply RoPE to only the first rotary_dim dimensions."""
|
||||
xq_rot = xq[..., :rotary_dim]
|
||||
xq_pass = xq[..., rotary_dim:]
|
||||
xk_rot = xk[..., :rotary_dim]
|
||||
xk_pass = xk[..., rotary_dim:]
|
||||
|
||||
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
|
||||
|
||||
xq = torch.cat([xq_rot, xq_pass], dim=-1)
|
||||
xk = torch.cat([xk_rot, xk_pass], dim=-1)
|
||||
return xq, xk
|
||||
|
||||
|
||||
class GatedAttention(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
||||
|
||||
# q_proj outputs 2x: query + gate
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
# QK norms with (1+weight) scaling
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
batch_size, seq_length, _ = x.shape
|
||||
|
||||
# Project Q (with gate), K, V
|
||||
qg = self.q_proj(x)
|
||||
# Split into query and gate: each is [B, seq, inner_size]
|
||||
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
|
||||
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
|
||||
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
|
||||
|
||||
xk = self.k_proj(x)
|
||||
xv = self.v_proj(x)
|
||||
|
||||
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
|
||||
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
|
||||
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply partial RoPE
|
||||
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
|
||||
|
||||
# KV cache
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
past_key, past_value, index = past_key_value
|
||||
num_tokens = xk.shape[2]
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + num_tokens] = xk
|
||||
past_value[:, :, index:index + num_tokens] = xv
|
||||
xk = past_key[:, :, :index + num_tokens]
|
||||
xv = past_value[:, :, :index + num_tokens]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
if index > 0:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
output = output * gate.sigmoid()
|
||||
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
|
||||
# Hybrid Transformer Block
|
||||
class Qwen35TransformerBlock(nn.Module):
|
||||
def __init__(self, config, index, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[index]
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
|
||||
else:
|
||||
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
if self.layer_type == "linear_attention":
|
||||
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
|
||||
else:
|
||||
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
|
||||
|
||||
x = x + h
|
||||
x = x + self.mlp(self.post_attention_layernorm(x))
|
||||
return x, present_key_value
|
||||
|
||||
|
||||
# Qwen35 Transformer Backbone
|
||||
class Qwen35Transformer(Llama2_):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if layer.layer_type == "full_attention":
|
||||
if len(past_key_values) > i:
|
||||
return past_key_values[i][2]
|
||||
break
|
||||
return 0
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
|
||||
return precompute_partial_rope(
|
||||
self.config.head_dim, rotary_dim, position_ids,
|
||||
self.config.rope_theta, device=device,
|
||||
mrope_section=self.config.mrope_section,
|
||||
)
|
||||
|
||||
|
||||
# Vision Encoder
|
||||
class Qwen35VisionPatchEmbed(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.patch_size = config["patch_size"]
|
||||
self.temporal_patch_size = config["temporal_patch_size"]
|
||||
self.in_channels = config["in_channels"]
|
||||
self.embed_dim = config["hidden_size"]
|
||||
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
|
||||
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
target_dtype = self.proj.weight.dtype
|
||||
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
|
||||
|
||||
|
||||
class Qwen35VisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
|
||||
|
||||
|
||||
class Qwen35VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta=10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, seqlen):
|
||||
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(seq, self.inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class Qwen35VisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
seq_length = x.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
|
||||
|
||||
# Process per-sequence attention
|
||||
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_splits = torch.split(query_states, lengths, dim=0)
|
||||
k_splits = torch.split(key_states, lengths, dim=0)
|
||||
v_splits = torch.split(value_states, lengths, dim=0)
|
||||
|
||||
attn_outputs = []
|
||||
for q, k, v in zip(q_splits, k_splits, v_splits):
|
||||
q = q.transpose(0, 1).unsqueeze(0)
|
||||
k = k.transpose(0, 1).unsqueeze(0)
|
||||
v = v.transpose(0, 1).unsqueeze(0)
|
||||
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return self.proj(attn_output)
|
||||
|
||||
|
||||
class Qwen35VisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Qwen35VisionPatchMerger(nn.Module):
|
||||
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||
self.merge_dim = merge_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x).view(-1, self.merge_dim)
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||
|
||||
|
||||
class Qwen35VisionModel(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = config["spatial_merge_size"]
|
||||
self.patch_size = config["patch_size"]
|
||||
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_heads"]
|
||||
self.num_position_embeddings = config["num_position_embeddings"]
|
||||
|
||||
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
|
||||
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
|
||||
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
|
||||
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||
freq_table = self.rotary_pos_emb(max_hw)
|
||||
device = freq_table.device
|
||||
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw_list:
|
||||
num_frames, height, width = int(num_frames), int(height), int(width)
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
block_rows = torch.arange(merged_h, device=device)
|
||||
block_cols = torch.arange(merged_w, device=device)
|
||||
intra_row = torch.arange(merge_size, device=device)
|
||||
intra_col = torch.arange(merge_size, device=device)
|
||||
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
num_tokens = coords.shape[0]
|
||||
pos_ids[offset:offset + num_tokens] = coords
|
||||
offset += num_tokens
|
||||
embeddings = freq_table[pos_ids]
|
||||
embeddings = embeddings.flatten(1)
|
||||
return embeddings
|
||||
|
||||
def fast_pos_embed_interpolate(self, grid_thw):
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_ts = [int(row[0]) for row in grid_thw_list]
|
||||
grid_hs = [int(row[1]) for row in grid_thw_list]
|
||||
grid_ws = [int(row[2]) for row in grid_thw_list]
|
||||
device = self.pos_embed.weight.device
|
||||
idx_list = [[] for _ in range(4)]
|
||||
weight_list = [[] for _ in range(4)]
|
||||
for t, h, w in grid_thw_list:
|
||||
h, w = int(h), int(w)
|
||||
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
|
||||
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
|
||||
h_idxs_floor = h_idxs.int()
|
||||
w_idxs_floor = w_idxs.int()
|
||||
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
dh = h_idxs - h_idxs_floor
|
||||
dw = w_idxs - w_idxs_floor
|
||||
base_h = h_idxs_floor * self.num_grid_per_side
|
||||
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
||||
indices = [
|
||||
(base_h[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h[None].T + w_idxs_ceil[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
|
||||
]
|
||||
weights = [
|
||||
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||
((1 - dh)[None].T * dw[None]).flatten(),
|
||||
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||
(dh[None].T * dw[None]).flatten(),
|
||||
]
|
||||
for j in range(4):
|
||||
idx_list[j].extend(indices[j].tolist())
|
||||
weight_list[j].extend(weights[j].tolist())
|
||||
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
|
||||
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
|
||||
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||
patch_pos_embeds_permute = []
|
||||
merge_size = self.spatial_merge_size
|
||||
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
|
||||
pos_embed = pos_embed.repeat(t, 1)
|
||||
pos_embed = (
|
||||
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||
.permute(0, 1, 3, 2, 4, 5)
|
||||
.flatten(0, 4)
|
||||
)
|
||||
patch_pos_embeds_permute.append(pos_embed)
|
||||
return torch.cat(patch_pos_embeds_permute)
|
||||
|
||||
def forward(self, x, grid_thw):
|
||||
x = self.patch_embed(x)
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||
x = x + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
seq_len = x.shape[0]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos().unsqueeze(-2)
|
||||
sin = emb.sin().unsqueeze(-2)
|
||||
sin_half = sin.shape[-1] // 2
|
||||
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
merged = self.merger(x)
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen35_2b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = _make_config(self.model_type, config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
|
||||
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for i in range(model_config.num_hidden_layers):
|
||||
if model_config.layer_types[i] == "linear_attention":
|
||||
recurrent_state = torch.zeros(
|
||||
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
|
||||
device=device, dtype=torch.float32
|
||||
)
|
||||
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
|
||||
conv_state = torch.zeros(
|
||||
[batch, conv_dim, model_config.conv_kernel_size - 1],
|
||||
device=device, dtype=execution_dtype
|
||||
)
|
||||
past_key_values.append((recurrent_state, conv_state, 0))
|
||||
else:
|
||||
past_key_values.append((
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
0
|
||||
))
|
||||
return past_key_values
|
||||
|
||||
# Tokenizer and Text Encoder Wrappers
|
||||
|
||||
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
|
||||
from transformers import Qwen2Tokenizer
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
|
||||
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
|
||||
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image]
|
||||
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
if not thinking:
|
||||
llama_text += "<think>\n</think>\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 248056: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||
class Qwen35_(Qwen35):
|
||||
pass
|
||||
Qwen35_.model_type = model_type
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen35_2b"):
|
||||
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen35ImageTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||
class Qwen35TEModel_(Qwen35TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen35TEModel_
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -145,20 +145,7 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
widgets.upscale = "enabled" ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -238,21 +225,13 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -348,21 +327,13 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@@ -38,7 +38,6 @@ from comfy_api_nodes.util import (
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,67 +1,85 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
|
||||
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@@ -69,9 +87,7 @@ class SizeModeInput(TypedDict):
|
||||
|
||||
|
||||
MAX_IMAGES = 5 # u_image0-4
|
||||
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||
MAX_BOOLS = 10 # u_bool0-9
|
||||
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
|
||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||
@@ -84,7 +100,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@@ -108,21 +124,14 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@@ -148,8 +157,163 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -161,111 +325,131 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if not self._display:
|
||||
raise RuntimeError("eglGetDisplay() returned no display")
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
if not EGL.eglInitialize(self._display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
err = EGL.eglGetError()
|
||||
self._display = None
|
||||
raise RuntimeError(f"eglInitialize() failed (EGL error: 0x{err:04X})")
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context)
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@@ -273,10 +457,8 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@@ -300,10 +482,8 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@@ -317,8 +497,6 @@ def _render_shader_batch(
|
||||
image_batches: list[list[np.ndarray]],
|
||||
floats: list[float],
|
||||
ints: list[int],
|
||||
bools: list[bool] | None = None,
|
||||
curves: list[np.ndarray] | None = None,
|
||||
) -> list[list[np.ndarray]]:
|
||||
"""
|
||||
Render a fragment shader for multiple batches efficiently.
|
||||
@@ -333,8 +511,6 @@ def _render_shader_batch(
|
||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||
floats: List of float uniforms
|
||||
ints: List of int uniforms
|
||||
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||
|
||||
Returns:
|
||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||
@@ -348,23 +524,20 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
# Detect multi-pass rendering
|
||||
num_passes = _detect_pass_count(fragment_code)
|
||||
|
||||
if bools is None:
|
||||
bools = []
|
||||
if curves is None:
|
||||
curves = []
|
||||
|
||||
# Track resources for cleanup
|
||||
program = None
|
||||
fbo = None
|
||||
output_textures = []
|
||||
input_textures = []
|
||||
curve_textures = []
|
||||
ping_pong_textures = []
|
||||
ping_pong_fbos = []
|
||||
|
||||
@@ -373,9 +546,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@@ -451,28 +624,6 @@ def _render_shader_batch(
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, v)
|
||||
|
||||
for i, v in enumerate(bools):
|
||||
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, 1 if v else 0)
|
||||
|
||||
# Create 1D LUT textures for curves (bound after image texture units)
|
||||
for i, lut in enumerate(curves):
|
||||
tex = gl.glGenTextures(1)
|
||||
curve_textures.append(tex)
|
||||
unit = MAX_IMAGES + i
|
||||
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||
|
||||
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, unit)
|
||||
|
||||
# Get u_pass uniform location for multi-pass
|
||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||
|
||||
@@ -538,13 +689,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
batch_outputs = []
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@@ -565,18 +716,16 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
@@ -605,20 +754,6 @@ class GLSLShader(io.ComfyNode):
|
||||
max=MAX_UNIFORMS,
|
||||
)
|
||||
|
||||
bool_template = io.Autogrow.TemplatePrefix(
|
||||
io.Boolean.Input("bool", default=False),
|
||||
prefix="u_bool",
|
||||
min=0,
|
||||
max=MAX_BOOLS,
|
||||
)
|
||||
|
||||
curve_template = io.Autogrow.TemplatePrefix(
|
||||
io.Curve.Input("curve"),
|
||||
prefix="u_curve",
|
||||
min=0,
|
||||
max=MAX_CURVES,
|
||||
)
|
||||
|
||||
return io.Schema(
|
||||
node_id="GLSLShader",
|
||||
display_name="GLSL Shader",
|
||||
@@ -627,7 +762,6 @@ class GLSLShader(io.ComfyNode):
|
||||
"Apply GLSL ES fragment shaders to images. "
|
||||
"u_resolution (vec2) is always available."
|
||||
),
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
"fragment_shader",
|
||||
@@ -662,8 +796,6 @@ class GLSLShader(io.ComfyNode):
|
||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||
@@ -681,19 +813,13 @@ class GLSLShader(io.ComfyNode):
|
||||
images: io.Autogrow.Type,
|
||||
floats: io.Autogrow.Type = None,
|
||||
ints: io.Autogrow.Type = None,
|
||||
bools: io.Autogrow.Type = None,
|
||||
curves: io.Autogrow.Type = None,
|
||||
**kwargs,
|
||||
) -> io.NodeOutput:
|
||||
|
||||
image_list = [v for v in images.values() if v is not None]
|
||||
float_list = (
|
||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||
)
|
||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||
|
||||
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||
|
||||
if not image_list:
|
||||
raise ValueError("At least one input image is required")
|
||||
@@ -720,8 +846,6 @@ class GLSLShader(io.ComfyNode):
|
||||
image_batches,
|
||||
float_list,
|
||||
int_list,
|
||||
bool_list,
|
||||
curve_luts,
|
||||
)
|
||||
|
||||
# Collect outputs into tensors
|
||||
|
||||
@@ -44,13 +44,8 @@ class NumberConvertNode(io.ComfyNode):
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
int_val = 1 if value else 0
|
||||
elif isinstance(value, int):
|
||||
elif isinstance(value, (int, float)):
|
||||
float_val = float(value)
|
||||
int_val = value
|
||||
elif isinstance(value, float):
|
||||
float_val = value
|
||||
int_val = int(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
@@ -61,14 +56,6 @@ class NumberConvertNode(io.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
try:
|
||||
int_val = int(text)
|
||||
except ValueError:
|
||||
int_val = int(float_val)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
@@ -79,7 +66,7 @@ class NumberConvertNode(io.ComfyNode):
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int_val)
|
||||
return io.NodeOutput(float_val, int(float_val))
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
|
||||
@@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
|
||||
def g(cls, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
|
||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return (g / g.sum()).to(dtype)
|
||||
return g / g.sum()
|
||||
|
||||
class Blur(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = blur_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||
@@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
|
||||
image = image.to(comfy.model_management.get_torch_device())
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
||||
kernel = kernel.to(dtype=image.dtype)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
|
||||
@@ -15,7 +15,6 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
||||
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
|
||||
]
|
||||
),
|
||||
io.DynamicCombo.Option(
|
||||
@@ -26,7 +25,7 @@ class TextGenerate(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id="TextGenerate",
|
||||
category="textgen",
|
||||
category="textgen/",
|
||||
search_aliases=["LLM", "gemma"],
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
@@ -34,7 +33,6 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(display_name="generated_text"),
|
||||
@@ -42,9 +40,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
@@ -54,7 +52,6 @@ class TextGenerate(io.ComfyNode):
|
||||
min_p = sampling_mode.get("min_p", 0.0)
|
||||
seed = sampling_mode.get("seed", None)
|
||||
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
||||
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
|
||||
|
||||
generated_ids = clip.generate(
|
||||
tokens,
|
||||
@@ -65,7 +62,6 @@ class TextGenerate(io.ComfyNode):
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
@@ -160,12 +156,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
9
main.py
9
main.py
@@ -139,7 +139,16 @@ def execute_prestartup_script():
|
||||
spec.loader.exec_module(module)
|
||||
return True
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import NODE_STARTUP_ERRORS, get_module_name
|
||||
node_module_name = get_module_name(os.path.dirname(script_path))
|
||||
NODE_STARTUP_ERRORS[node_module_name] = {
|
||||
"module_path": os.path.dirname(script_path),
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "prestartup",
|
||||
}
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1
|
||||
comfyui_manager==4.1b8
|
||||
|
||||
10
nodes.py
10
nodes.py
@@ -2181,6 +2181,9 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by module name.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@@ -2298,6 +2301,13 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
module_name = get_module_name(module_path)
|
||||
NODE_STARTUP_ERRORS[module_name] = {
|
||||
"module_path": module_path,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "import",
|
||||
}
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.38
|
||||
comfyui-workflow-templates==0.9.36
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
PyOpenGL
|
||||
glfw
|
||||
|
||||
@@ -753,6 +753,10 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/custom_node_startup_errors")
|
||||
async def get_custom_node_startup_errors(request):
|
||||
return web.json_response(nodes.NODE_STARTUP_ERRORS)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Tests for path_utils – asset category resolution."""
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dirs():
|
||||
"""Create temporary input, output, and temp directories."""
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
models_dir = root_path / "models" / "checkpoints"
|
||||
for d in (input_dir, output_dir, temp_dir, models_dir):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(models_dir)])],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"models": models_dir,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAssetCategoryAndRelativePath:
|
||||
def test_input_file(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "photo.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "input"
|
||||
assert rel == "photo.png"
|
||||
|
||||
def test_output_file(self, fake_dirs):
|
||||
f = fake_dirs["output"] / "result.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "output"
|
||||
assert rel == "result.png"
|
||||
|
||||
def test_temp_file(self, fake_dirs):
|
||||
"""Regression: temp files must be categorised, not raise ValueError."""
|
||||
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "temp"
|
||||
assert rel == "GLSLShader_output_00004_.png"
|
||||
|
||||
def test_temp_file_in_subfolder(self, fake_dirs):
|
||||
sub = fake_dirs["temp"] / "sub"
|
||||
sub.mkdir()
|
||||
f = sub / "ComfyUI_temp_tczip_00004_.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "temp"
|
||||
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
|
||||
|
||||
def test_model_file(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "model.safetensors"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "models"
|
||||
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
@@ -90,63 +90,6 @@ class TestNumberConvertExecute:
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- Large number precision (string input) ---
|
||||
|
||||
def test_string_large_int_above_2_53(self):
|
||||
"""Text-to-int must not lose precision for integers beyond 2^53."""
|
||||
big = 2**53 + 1 # 9007199254740993
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_large_negative_int_above_2_53(self):
|
||||
big = -(2**53 + 1)
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_very_large_int(self):
|
||||
big = 2**63 + 42
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_large_int_float_output_is_float(self):
|
||||
"""FLOAT output is still a float (may lose precision, but must be float type)."""
|
||||
result = self._exec(str(2**53 + 1))
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
# --- Large number precision (int input) ---
|
||||
|
||||
def test_int_large_above_2_53(self):
|
||||
"""Native int input must preserve its value in the INT output."""
|
||||
big = 2**53 + 1
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
def test_int_large_negative_above_2_53(self):
|
||||
big = -(2**53 + 1)
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
def test_int_very_large(self):
|
||||
big = 2**100
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
# --- String decimal / scientific notation fallback ---
|
||||
|
||||
def test_string_decimal_still_truncates(self):
|
||||
"""Strings with decimal points fall back to int(float(...)) truncation."""
|
||||
result = self._exec("3.7")
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative_decimal_truncates(self):
|
||||
result = self._exec("-2.9")
|
||||
assert result[1] == -2
|
||||
|
||||
def test_string_scientific_large(self):
|
||||
result = self._exec("1e18")
|
||||
assert result[0] == 1e18
|
||||
assert result[1] == 10**18
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
|
||||
@@ -24,7 +24,6 @@ def init_mime_types():
|
||||
# Web types (used by server.py for static file serving)
|
||||
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
||||
mimetypes.add_type('image/webp', '.webp')
|
||||
mimetypes.add_type('image/svg+xml', '.svg')
|
||||
|
||||
# Model and data file types (used by asset scanning / metadata extraction)
|
||||
mimetypes.add_type("application/safetensors", ".safetensors")
|
||||
|
||||
Reference in New Issue
Block a user