Compare commits

..

18 Commits

Author SHA1 Message Date
bigcat88
92890ef01d feat(api-nodes-Tencent3D): allow a smaller minimum face_count; add uv_image output
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-28 14:36:06 +02:00
rattus
b353a7c863 Integrate RAM cache with model RAM management (#13173) 2026-03-27 21:34:16 -04:00
Terry Jia
3696c5bad6 Add has_intermediate_output flag for nodes with interactive UI (#13048) 2026-03-27 21:06:38 -04:00
comfyanonymous
3a56201da5 Allow flux conditioning without a pooled output. (#13198) 2026-03-27 20:36:26 -04:00
Alexander Piskun
6a2cdb817d fix(api-nodes-nanobana): raise error when no output image is present (#13167)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-27 12:11:41 -07:00
ComfyUI Wiki
85b7495135 chore: update workflow templates to v0.9.39 (#13196) 2026-03-27 10:13:02 -07:00
Jin Yi
225c52f6a4 fix: register image/svg+xml MIME type for .svg files (#13186)
The /view endpoint returns text/plain for .svg files on some platforms
because Python's mimetypes module does not always include SVG by default.
Explicitly register image/svg+xml so <img> tags can render SVGs correctly (a minimal sketch of the idea follows the commit list below).

Amp-Thread-ID: https://ampcode.com/threads/T-019d2da7-6a64-726a-af91-bd9c44e7f43c
2026-03-26 22:13:29 -07:00
comfyanonymous
b1fdbeb9a7 Fix blur and sharpen nodes not working with fp16 intermediates. (#13181) 2026-03-26 22:18:16 -04:00
Terry Jia
1dc64f3526 feat: add curve inputs and raise uniform limit for GLSL shader node (#13158)
* feat: add curve inputs and raise uniform limit for GLSL shader node

* allow arbitrary size for curve
2026-03-26 21:45:05 -04:00
ComfyUI Wiki
359559c913 chore: update workflow templates to v0.9.38 (#13176) 2026-03-26 12:07:38 -07:00
Alexander Piskun
8165485a17 feat(api-nodes): added new Topaz model (#13175)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-26 12:02:04 -07:00
Jukka Seppänen
b0fd65e884 fix: regression in text generate with LTXAV model (#13170) 2026-03-26 09:55:05 -07:00
comfyanonymous
2a1f402601 Make Qwen 8B work with TextGenerate node. (#13160) 2026-03-25 23:21:44 -04:00
Luke Mino-Altherr
3eba2dcf2d fix(assets): recognize temp directory in asset category resolution (#13159) 2026-03-25 19:59:59 -07:00
Jukka Seppänen
404d7b9978 feat: Support Qwen3.5 text generation models (#12771) 2026-03-25 22:48:28 -04:00
Dante
6580a6bc01 fix(number-convert): preserve int precision for large numbers (#13147) 2026-03-25 18:06:34 -04:00
Dr.Lt.Data
3b15651bc6 bump manager version to 4.1 (#13156) 2026-03-25 16:49:29 -04:00
Alexander Piskun
a55835f10c fix(api-nodes): made Reve node price badges more precise (#13154)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-03-25 11:05:49 -07:00
41 changed files with 497396 additions and 127 deletions
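The SVG fix in 225c52f6a4 comes down to registering the type with Python's mimetypes module before the server starts serving files; a minimal sketch of the idea:

import mimetypes

# Python's bundled MIME table omits SVG on some platforms, so /view
# would fall back to text/plain for .svg files.
mimetypes.add_type("image/svg+xml", ".svg")

print(mimetypes.guess_type("preview.svg"))  # ('image/svg+xml', None)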

View File

@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
) -> tuple[Literal["input", "output", "temp", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'temp': under folder_paths.get_temp_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
@@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)
# 4) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
@@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
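A hedged illustration of the updated resolver (the paths below are hypothetical and assume the default folder_paths layout):

# Hypothetical paths, assuming default folder_paths configuration.
category, rel = get_asset_category_and_relative_path("/comfy/temp/preview_0001.png")
assert category == "temp" and rel == "preview_0001.png"

# Anything outside input/output/temp/models now raises with the updated message:
try:
    get_asset_category_and_relative_path("/etc/passwd")
except ValueError as err:
    print(err)  # Path is not within input, output, temp, or configured model bases: ...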

View File

@@ -0,0 +1,90 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 rgb2hsl(vec3 c) {
float maxC = max(c.r, max(c.g, c.b));
float minC = min(c.r, min(c.g, c.b));
float l = (maxC + minC) * 0.5;
if (maxC == minC) return vec3(0.0, 0.0, l);
float d = maxC - minC;
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
float h;
if (maxC == c.r) {
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / d + 2.0;
} else {
h = (c.r - c.g) / d + 4.0;
}
h /= 6.0;
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
if (t < 0.0) t += 1.0;
if (t > 1.0) t -= 1.0;
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
if (t < 1.0 / 2.0) return q;
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
float h = hsl.x, s = hsl.y, l = hsl.z;
if (s == 0.0) return vec3(l);
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
float p = 2.0 * l - q;
return vec3(
hue2rgb(p, q, h + 1.0 / 3.0),
hue2rgb(p, q, h),
hue2rgb(p, q, h - 1.0 / 3.0)
);
}
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float lightness = (maxC + minC) * 0.5;
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
const float a = 0.25;
const float b = 0.333;
const float scale = 0.7;
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
color += sw * shadows + mw * midtones + hw * highlights;
if (u_bool0) {
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
hsl.z = lightness;
color = hsl2rgb(hsl);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}
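For readers without a GLSL toolchain, the weight ramps above translate directly to NumPy; a sketch (not part of the diff) of the same sw/mw/hw computation:

import numpy as np

A, B, SCALE = 0.25, 0.333, 0.7  # constants from the shader above

def balance_weights(lightness):
    # Shadow/midtone/highlight weights, per the GIMP-style linear ramps.
    l = np.asarray(lightness, dtype=np.float32)
    sw = np.clip((l - B) / -A + 0.5, 0.0, 1.0) * SCALE
    mw = (np.clip((l - B) / A + 0.5, 0.0, 1.0)
          * np.clip((l + B - 1.0) / -A + 0.5, 0.0, 1.0) * SCALE)
    hw = np.clip((l + B - 1.0) / A + 0.5, 0.0, 1.0) * SCALE
    return sw, mw, hw

# Dark pixels weight shadows, bright pixels weight highlights:
print(balance_weights([0.1, 0.5, 0.9]))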

View File

@@ -0,0 +1,46 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
// index = value * (n_samples - 1)
// f = fract(index)
// result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
value = clamp(value, 0.0, 1.0);
float pos = value * 255.0;
int lo = int(floor(pos));
int hi = min(lo + 1, 255);
float f = pos - float(lo);
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
return a + f * (b - a);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) )
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
fragColor0 = vec4(color.rgb, color.a);
}
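The texelFetch lookup mirrors GIMP's gimp_curve_map_value_inline(); the same interpolation in NumPy (illustrative only):

import numpy as np

def map_value(samples, value):
    # Linear interpolation into a 256-entry LUT, GIMP style.
    pos = np.clip(value, 0.0, 1.0) * (len(samples) - 1)
    lo = int(np.floor(pos))
    hi = min(lo + 1, len(samples) - 1)
    f = pos - lo
    return (1.0 - f) * samples[lo] + f * samples[hi]

identity = np.linspace(0.0, 1.0, 256, dtype=np.float32)
print(map_value(identity, 0.7))  # ~0.7 on the identity curve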

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
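The CACHE_RAM_AUTO_GB sentinel is presumably resolved elsewhere into the documented default; a hedged sketch of that policy (resolve_headroom is a hypothetical name, not the PR's):

import psutil

CACHE_RAM_AUTO_GB = -1.0

def resolve_headroom(cache_ram_gb):
    # Hypothetical helper: 25% of system RAM, clamped to [4, 32] GB,
    # when --cache-ram is passed without an explicit value.
    if cache_ram_gb == CACHE_RAM_AUTO_GB:
        total_gb = psutil.virtual_memory().total / (1024 ** 3)
        return min(max(total_gb * 0.25, 4.0), 32.0)
    return cache_ram_gb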

View File

@@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_enabled = False
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0
def set_ram_cache_release_state(callback, headroom):
global extra_ram_release_callback
global RAM_CACHE_HEADROOM
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)
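How a cache might hook into this release API (a sketch; the real wiring lands with the RAM cache integration in #13173):

import comfy.memory_management

def release_large_items(target_bytes):
    # Hypothetical callback: evict cached items until roughly
    # target_bytes are freed, returning the amount released.
    freed = 0
    # (eviction loop over cache entries would go here)
    return freed

# Register with an 8 GB headroom; the value is illustrative and in bytes,
# matching the psutil comparison in pin_memory below.
comfy.memory_management.set_ram_cache_release_state(release_large_items, 8 * 1024 ** 3)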

View File

@@ -890,7 +890,7 @@ class Flux(BaseModel):
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
return kwargs.get("pooled_output", None)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)

View File

@@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
@@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
@@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0:
soft_empty_cache()
else:
elif device is not None:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:

View File

@@ -300,9 +300,6 @@ class ModelPatcher:
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory

View File

@@ -928,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
@@ -1034,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:

View File

@@ -2,6 +2,7 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import psutil
from comfy.cli_args import args
@@ -12,6 +13,11 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
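The half-headroom rule in numbers (illustrative values): with an 8 GB headroom, pinning asks the cache to release memory once available RAM dips below 4 GB:

ram_headroom = 8 * 1024 ** 3   # bytes, via set_ram_cache_release_state
available = 3 * 1024 ** 3      # example psutil.virtual_memory().available
if ram_headroom > 0 and available < ram_headroom * 0.5:
    print(f"triggering extra_ram_release({ram_headroom})")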

View File

@@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.model_patcher
import comfy.lora
@@ -279,9 +280,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -425,13 +423,13 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -839,9 +837,6 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -1228,6 +1223,11 @@ class TEModel(Enum):
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
QWEN35_08B = 23
QWEN35_2B = 24
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27
def detect_te_model(sd):
@@ -1267,6 +1267,17 @@ def detect_te_model(sd):
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
weight = sd['model.language_model.layers.0.input_layernorm.weight']
if weight.shape[0] == 1024:
return TEModel.QWEN35_08B
if weight.shape[0] == 2560:
return TEModel.QWEN35_4B
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1299,11 +1310,12 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
for weight_name in weight_names:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
@@ -1431,6 +1443,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer

View File

@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
def parse_parentheses(string):
result = []
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

View File

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
lm_head: bool = True
stop_tokens = [151643, 151645]
@dataclass
@@ -655,6 +655,17 @@ class Llama2_(nn.Module):
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
return past_key_values[0][2]
def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
@@ -667,17 +678,12 @@ class Llama2_(nn.Module):
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]
past_len = self.get_past_len(past_key_values)
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
mask = None
if attention_mask is not None:
@@ -812,9 +818,16 @@ class BaseGenerate:
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
device = embeds.device
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
@@ -829,11 +842,8 @@ class BaseGenerate:
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
@@ -844,7 +854,7 @@ class BaseGenerate:
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
@@ -856,7 +866,7 @@ class BaseGenerate:
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
@@ -867,6 +877,11 @@ class BaseGenerate:
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty
if temperature != 1.0:
logits = logits / temperature
@@ -897,6 +912,9 @@ class BaseGenerate:
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)
module = self.model.embed_tokens
offload_stream = None
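Presence penalty, as threaded through above, is a flat subtraction on every already-generated token's logit (contrast with the multiplicative repetition penalty); a standalone torch sketch:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
seen = {0, 2}            # token ids already generated
presence_penalty = 0.6

for token_id in seen:
    logits[0, token_id] -= presence_penalty
print(logits)  # tensor([[ 1.4000,  1.0000, -0.1000,  0.1000]])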

View File

@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:

View File

@@ -0,0 +1,833 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
def _qwen35_layer_types(n):
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
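# Example: n=24 yields layers 0-2 linear_attention, layer 3 full_attention,
# layers 4-6 linear_attention, layer 7 full_attention, ... i.e. every 4th
# block is full (gated) attention and the rest are DeltaNet layers.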
@dataclass
class Qwen35Config:
vocab_size: int = 248320
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 24
# Full attention params
num_attention_heads: int = 8
num_key_value_heads: int = 2
head_dim: int = 256
partial_rotary_factor: float = 0.25
# Linear attention (DeltaNet) params
linear_num_key_heads: int = 16
linear_num_value_heads: int = 16
linear_key_head_dim: int = 128
linear_value_head_dim: int = 128
conv_kernel_size: int = 4
# Shared params
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 10000000.0
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
rms_norm_add: bool = True
mlp_activation: str = "silu"
qkv_bias: bool = False
final_norm: bool = True
lm_head: bool = False
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
transformer_type: str = "qwen35_2b"
rope_dims: list = None
rope_scale: float = None
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
QWEN35_MODELS = {
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
}
def _make_config(model_type, config_dict={}):
overrides = QWEN35_MODELS.get(model_type, {}).copy()
overrides.pop("vision", None)
if "num_hidden_layers" in overrides:
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
overrides.update(config_dict)
return Qwen35Config(**overrides)
class RMSNormGated(RMSNorm):
def forward(self, x, gate):
return super().forward(x) * F.silu(gate.to(x.dtype))
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
initial_dtype = query.dtype
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size))
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
total_sequence_length = sequence_length + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
for i in range(0, total_sequence_length // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
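# Shape sketch (illustrative): query/key enter as [batch, seq, heads, k_dim],
# value as [batch, seq, heads, v_dim], g/beta as [batch, seq, heads]. Returns
# core_attn_out as [batch, seq, heads, v_dim] plus, when output_final_state
# is True, the final [batch, heads, k_dim, v_dim] recurrent state.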
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
# weight: [channels, kernel_size]
state_len = conv_state.shape[-1]
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
conv_state.copy_(combined[:, :, -state_len:])
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
if bias is not None:
out = out + bias.unsqueeze(0).unsqueeze(-1)
return F.silu(out).to(x.dtype)
# GatedDeltaNet - Linear Attention Layer
class GatedDeltaNet(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden = config.hidden_size
self.num_key_heads = config.linear_num_key_heads
self.num_value_heads = config.linear_num_value_heads
self.key_head_dim = config.linear_key_head_dim
self.value_head_dim = config.linear_value_head_dim
self.conv_kernel_size = config.conv_kernel_size
key_dim = self.num_key_heads * self.key_head_dim
value_dim = self.num_value_heads * self.value_head_dim
self.key_dim = key_dim
self.value_dim = value_dim
conv_dim = key_dim * 2 + value_dim
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(self, x, past_key_value=None, **kwargs):
batch_size, seq_len, _ = x.shape
use_recurrent = (
past_key_value is not None
and past_key_value[2] > 0
and seq_len == 1
)
# Projections (shared)
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
z = self.in_proj_z(x)
b = self.in_proj_b(x)
a = self.in_proj_a(x)
# Conv1d
if use_recurrent:
recurrent_state, conv_state, step_index = past_key_value
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
else:
if past_key_value is not None:
recurrent_state, conv_state, step_index = past_key_value
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# Split QKV and compute beta/g
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
# Delta rule
if use_recurrent:
# single-token path: work in [B, heads, dim] without seq dim
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=1)
key = key.repeat_interleave(rep, dim=1)
scale = self.key_head_dim ** -0.5
q = F.normalize(query.float(), dim=-1) * scale
k = F.normalize(key.float(), dim=-1)
v = value.float()
beta_t = beta.reshape(batch_size, -1)
g_t = g.reshape(batch_size, -1).exp()
# In-place state update: [B, heads, k_dim, v_dim]
recurrent_state.mul_(g_t[:, :, None, None])
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
delta = (v - kv_mem) * beta_t[:, :, None]
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
present_key_value = (recurrent_state, conv_state, step_index + 1)
else:
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=2)
key = key.repeat_interleave(rep, dim=2)
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
query, key, value, g=g, beta=beta,
initial_state=None,
output_final_state=past_key_value is not None,
)
present_key_value = None
if past_key_value is not None:
if last_recurrent_state is not None:
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
# Gated norm + output projection (shared)
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
return output, present_key_value
# GatedAttention - Full Attention with output gating
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
"""Compute RoPE frequencies for partial rotary embeddings."""
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if mrope_section is not None and position_ids.shape[0] == 3:
mrope_section_2 = [s * 2 for s in mrope_section]
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
"""Apply RoPE to only the first rotary_dim dimensions."""
xq_rot = xq[..., :rotary_dim]
xq_pass = xq[..., rotary_dim:]
xk_rot = xk[..., :rotary_dim]
xk_pass = xk[..., rotary_dim:]
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
xq = torch.cat([xq_rot, xq_pass], dim=-1)
xk = torch.cat([xk_rot, xk_pass], dim=-1)
return xq, xk
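# With the defaults above (head_dim=256, partial_rotary_factor=0.25) only the
# first 64 dimensions of each head are rotated; the other 192 pass through.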
class GatedAttention(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.inner_size = self.num_heads * self.head_dim
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
# q_proj outputs 2x: query + gate
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
# QK norms with (1+weight) scaling
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
batch_size, seq_length, _ = x.shape
# Project Q (with gate), K, V
qg = self.q_proj(x)
# Split into query and gate: each is [B, seq, inner_size]
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
xk = self.k_proj(x)
xv = self.v_proj(x)
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
# Apply partial RoPE
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
# KV cache
present_key_value = None
if past_key_value is not None:
past_key, past_value, index = past_key_value
num_tokens = xk.shape[2]
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + num_tokens] = xk
past_value[:, :, index:index + num_tokens] = xv
xk = past_key[:, :, :index + num_tokens]
xv = past_value[:, :, :index + num_tokens]
present_key_value = (past_key, past_value, index + num_tokens)
else:
if index > 0:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value
# Hybrid Transformer Block
class Qwen35TransformerBlock(nn.Module):
def __init__(self, config, index, device=None, dtype=None, ops=None):
super().__init__()
self.layer_type = config.layer_types[index]
if self.layer_type == "linear_attention":
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
else:
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
if self.layer_type == "linear_attention":
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
else:
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
x = x + h
x = x + self.mlp(self.post_attention_layernorm(x))
return x, present_key_value
# Qwen35 Transformer Backbone
class Qwen35Transformer(Llama2_):
def __init__(self, config, device=None, dtype=None, ops=None):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
for i, layer in enumerate(self.layers):
if layer.layer_type == "full_attention":
if len(past_key_values) > i:
return past_key_values[i][2]
break
return 0
def compute_freqs_cis(self, position_ids, device):
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
return precompute_partial_rope(
self.config.head_dim, rotary_dim, position_ids,
self.config.rope_theta, device=device,
mrope_section=self.config.mrope_section,
)
# Vision Encoder
class Qwen35VisionPatchEmbed(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.temporal_patch_size = config["temporal_patch_size"]
self.in_channels = config["in_channels"]
self.embed_dim = config["hidden_size"]
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, hidden_state):
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
class Qwen35VisionRotaryEmbedding(nn.Module):
def __init__(self, dim, theta=10000.0):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen):
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Qwen35VisionAttention(nn.Module):
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.dim = hidden_size
self.num_heads = num_heads
self.head_dim = self.dim // self.num_heads
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
seq_length = x.shape[0]
query_states, key_states, value_states = (
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
# Process per-sequence attention
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_splits = torch.split(query_states, lengths, dim=0)
k_splits = torch.split(key_states, lengths, dim=0)
v_splits = torch.split(value_states, lengths, dim=0)
attn_outputs = []
for q, k, v in zip(q_splits, k_splits, v_splits):
q = q.transpose(0, 1).unsqueeze(0)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
return self.proj(attn_output)
class Qwen35VisionBlock(nn.Module):
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
return x + self.mlp(self.norm2(x))
class Qwen35VisionPatchMerger(nn.Module):
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
super().__init__()
merge_dim = hidden_size * (spatial_merge_size ** 2)
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
self.merge_dim = merge_dim
def forward(self, x):
x = self.norm(x).view(-1, self.merge_dim)
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
class Qwen35VisionModel(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.spatial_merge_size = config["spatial_merge_size"]
self.patch_size = config["patch_size"]
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_heads"]
self.num_position_embeddings = config["num_position_embeddings"]
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
self.blocks = nn.ModuleList([
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
for _ in range(config["depth"])
])
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
def rot_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
grid_thw_list = grid_thw.tolist()
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
freq_table = self.rotary_pos_emb(max_hw)
device = freq_table.device
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw_list:
num_frames, height, width = int(num_frames), int(height), int(width)
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra_row = torch.arange(merge_size, device=device)
intra_col = torch.arange(merge_size, device=device)
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
num_tokens = coords.shape[0]
pos_ids[offset:offset + num_tokens] = coords
offset += num_tokens
embeddings = freq_table[pos_ids]
embeddings = embeddings.flatten(1)
return embeddings
def fast_pos_embed_interpolate(self, grid_thw):
grid_thw_list = grid_thw.tolist()
grid_ts = [int(row[0]) for row in grid_thw_list]
grid_hs = [int(row[1]) for row in grid_thw_list]
grid_ws = [int(row[2]) for row in grid_thw_list]
device = self.pos_embed.weight.device
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
for t, h, w in grid_thw_list:
h, w = int(h), int(w)
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for j in range(4):
idx_list[j].extend(indices[j].tolist())
weight_list[j].extend(weights[j].tolist())
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
patch_pos_embeds_permute = []
merge_size = self.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
def forward(self, x, grid_thw):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().unsqueeze(-2)
sin = emb.sin().unsqueeze(-2)
sin_half = sin.shape[-1] // 2
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
merged = self.merger(x)
return merged
# Model Wrapper
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
model_type = "qwen35_2b"
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = _make_config(self.model_type, config_dict)
self.num_layers = config.num_hidden_layers
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
grid = None
position_ids = None
offset = 0
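# Build Qwen-VL style 3D rotary position ids (temporal, height, width): text
# tokens advance all three axes together, while image tokens keep a constant
# temporal id and take row/column ids from the vision grid.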
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
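# Hybrid layer stack: linear-attention layers carry a fixed-size recurrent state
# plus a causal-conv state, while full-attention layers get preallocated
# (K, V, current_length) buffers sized to the maximum cache length.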
for i in range(model_config.num_hidden_layers):
if model_config.layer_types[i] == "linear_attention":
recurrent_state = torch.zeros(
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
device=device, dtype=torch.float32
)
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
conv_state = torch.zeros(
[batch, conv_dim, model_config.conv_kernel_size - 1],
device=device, dtype=execution_dtype
)
past_key_values.append((recurrent_state, conv_state, 0))
else:
past_key_values.append((
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
0
))
return past_key_values
# Tokenizer and Text Encoder Wrappers
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
from transformers import Qwen2Tokenizer
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 248056: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen35ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
class Qwen35_(Qwen35):
pass
Qwen35_.model_type = model_type
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen35TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
def tokenizer(model_type="qwen35_2b"):
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
return Qwen35ImageTokenizer_
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
class Qwen35TEModel_(Qwen35TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
return Qwen35TEModel_

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -1373,6 +1373,7 @@ class NodeInfoV1:
price_badge: dict | None = None
search_aliases: list[str]=None
essentials_category: str=None
has_intermediate_output: bool=None
@dataclass
@@ -1496,6 +1497,16 @@ class Schema:
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
has_intermediate_output: bool=False
"""Flags this node as having intermediate output that should persist across page refreshes.
Nodes with this flag behave like output nodes (their UI results are cached and resent
to the frontend) but do NOT automatically get added to the execution list. This means
they will only execute if they are on the dependency path of a real output node.
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
def validate(self):
'''Validate the schema:
@@ -1595,6 +1606,7 @@ class Schema:
category=self.category,
description=self.description,
output_node=self.is_output_node,
has_intermediate_output=self.has_intermediate_output,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
@@ -1886,6 +1898,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_HAS_INTERMEDIATE_OUTPUT = None
@final
@classproperty
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls.GET_SCHEMA()
return cls._HAS_INTERMEDIATE_OUTPUT
_INPUT_IS_LIST = None
@final
@classproperty
@@ -1978,6 +1998,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None:
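For illustration, a minimal sketch of a node opting into this flag, assuming the v3 io.Schema API shown above (the node id, display name and inputs here are hypothetical):
class ExamplePainterLikeNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExamplePainterLikeNode",  # hypothetical
            display_name="Example Painter-Like Node",
            category="image",
            has_intermediate_output=True,  # UI results cached and resent on refresh, not auto-queued
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )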

View File

@@ -201,6 +201,16 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
if not thought:
# No images generated --> extract text response for a meaningful error
model_message = get_text_from_response(response).strip()
if model_message:
raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
raise ValueError(
"Gemini did not generate an image. "
"Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
"to see the model's reasoning."
)
return torch.zeros((1, 1024, 1024, 4))
return torch.cat(image_tensors, dim=0)

View File

@@ -132,7 +132,7 @@ class TencentTextToModelNode(IO.ComfyNode):
tooltip="The LowPoly option is unavailable for the `3.1` model.",
),
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
@@ -251,7 +251,7 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.Image.Input("image_left", optional=True),
IO.Image.Input("image_right", optional=True),
IO.Image.Input("image_back", optional=True),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
@@ -422,6 +422,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(display_name="uv_image"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -468,9 +469,16 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
uv_image_file = get_file_from_response(result.ResultFile3Ds, "uv_image", raise_if_not_found=False)
uv_image = (
await download_url_to_image_tensor(uv_image_file.Url)
if uv_image_file is not None
else torch.zeros(1, 1, 1, 3)
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
uv_image,
)

View File

@@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
depends_on=IO.PriceBadgeDepends(
widgets=["upscale", "upscale.upscale_factor"],
),
expr="""
(
$factor := $lookup(widgets, "upscale.upscale_factor");
$fmt := {"approximate": true, "note": "(base)"};
widgets.upscale = "enabled" ? (
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
: {"type": "usd", "usd": 0.0457, "format": $fmt}
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
)
""",
),
)
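Read as a worked example: with upscale disabled the badge shows the base price of ~$0.03432; with upscale enabled the price follows the factor, e.g. factor 4 → ~$0.0762, factor 3 → ~$0.0591, and any other factor → ~$0.0457 (all values taken directly from the expression above).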
@@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
widgets=["model", "upscale", "upscale.upscale_factor"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
)
""",
),
@@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
widgets=["model", "upscale", "upscale.upscale_factor"],
),
expr="""
(
$fmt := {"approximate": true, "note": "(base)"};
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
$enabled := widgets.upscale = "enabled";
$factor := $lookup(widgets, "upscale.upscale_factor");
$isFast
? {"type": "usd", "usd": 0.01001, "format": $fmt}
: $enabled ? (
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
: {"type": "usd", "usd": 0.0686, "format": $fmt}
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
)
""",
),

View File

@@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}

View File

@@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
self._mark_used(node_id)
BasicCache.set_local(self, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids)
@@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self
#Iterating the cache for usage analysis might be expensive, so when eviction does
#trigger, take a sizeable chunk out to give breathing space on high-node / low-RAM-per-node flows.
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1
#This is roughly in GB, though not exactly. It needs to be non-zero for the heuristic
#below, and as long as multi-GB models dwarf it, the OOM scoring approximation holds up.
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#Exponential bias towards evicting older workflows, so stale entries get collected
#first in constantly changing setups.
@@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
def set_local(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
def ram_release(self, target):
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
ram_usage += output.numel() * output.element_size()
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
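A standalone sketch of the eviction scoring this loop implements, with simplified names (the real cache uses RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER, live psutil readings and per-output tensor scans; the age_bias value below is illustrative):
import bisect

def pick_evictions(entries, available, target, age_bias=1.25):
    """entries: iterable of (key, generations_stale, ram_bytes, last_touch)."""
    ranked = []
    for key, stale, ram_bytes, last_touch in entries:
        score = (age_bias ** stale) * ram_bytes
        # Ties on score fall back to the last-touch timestamp (pure LRU).
        bisect.insort(ranked, (score, last_touch, key, ram_bytes))
    evicted = []
    while available < target and ranked:
        _, _, key, ram_bytes = ranked.pop()  # evict the highest score first
        evicted.append(key)
        available += ram_bytes
    return evicted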

View File

@@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
MAX_BOOLS = 10 # u_bool0-9
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
@@ -497,6 +499,8 @@ def _render_shader_batch(
image_batches: list[list[np.ndarray]],
floats: list[float],
ints: list[int],
bools: list[bool] | None = None,
curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
@@ -511,6 +515,8 @@ def _render_shader_batch(
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
@@ -533,11 +539,17 @@ def _render_shader_batch(
# Detect multi-pass rendering
num_passes = _detect_pass_count(fragment_code)
if bools is None:
bools = []
if curves is None:
curves = []
# Track resources for cleanup
program = None
fbo = None
output_textures = []
input_textures = []
curve_textures = []
ping_pong_textures = []
ping_pong_fbos = []
@@ -624,6 +636,28 @@ def _render_shader_batch(
if loc >= 0:
gl.glUniform1i(loc, v)
for i, v in enumerate(bools):
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
if loc >= 0:
gl.glUniform1i(loc, 1 if v else 0)
# Create 1D LUT textures for curves (bound after image texture units)
for i, lut in enumerate(curves):
tex = gl.glGenTextures(1)
curve_textures.append(tex)
unit = MAX_IMAGES + i
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
if loc >= 0:
gl.glUniform1i(loc, unit)
# Get u_pass uniform location for multi-pass
pass_loc = gl.glGetUniformLocation(program, "u_pass")
@@ -718,6 +752,8 @@ def _render_shader_batch(
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
@@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
max=MAX_UNIFORMS,
)
bool_template = io.Autogrow.TemplatePrefix(
io.Boolean.Input("bool", default=False),
prefix="u_bool",
min=0,
max=MAX_BOOLS,
)
curve_template = io.Autogrow.TemplatePrefix(
io.Curve.Input("curve"),
prefix="u_curve",
min=0,
max=MAX_CURVES,
)
return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
@@ -762,6 +812,8 @@ class GLSLShader(io.ComfyNode):
"Apply GLSL ES fragment shaders to images. "
"u_resolution (vec2) is always available."
),
is_experimental=True,
has_intermediate_output=True,
inputs=[
io.String.Input(
"fragment_shader",
@@ -796,6 +848,8 @@ class GLSLShader(io.ComfyNode):
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
],
outputs=[
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
@@ -813,13 +867,19 @@ class GLSLShader(io.ComfyNode):
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
bools: io.Autogrow.Type = None,
curves: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
if not image_list:
raise ValueError("At least one input image is required")
@@ -846,6 +906,8 @@ class GLSLShader(io.ComfyNode):
image_batches,
float_list,
int_list,
bool_list,
curve_luts,
)
# Collect outputs into tensors
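As a usage sketch, a fragment shader applying the first curve as a per-channel tone map might look like the string below, assuming the node's shader preamble provides u_resolution, u_image0, u_curve0 and the fragColor0 output (declare them yourself otherwise):
EXAMPLE_TONE_CURVE_SHADER = """
void main() {
    vec2 uv = gl_FragCoord.xy / u_resolution;
    vec4 c = texture(u_image0, uv);
    // u_curve0 is a 1D LUT: sample at (value, 0.5) and read the red channel
    fragColor0 = vec4(texture(u_curve0, vec2(c.r, 0.5)).r,
                      texture(u_curve0, vec2(c.g, 0.5)).r,
                      texture(u_curve0, vec2(c.b, 0.5)).r,
                      c.a);
}
"""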

View File

@@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
display_name="Image Crop",
category="image/transform",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@@ -44,8 +44,13 @@ class NumberConvertNode(io.ComfyNode):
def execute(cls, value) -> io.NodeOutput:
if isinstance(value, bool):
float_val = 1.0 if value else 0.0
elif isinstance(value, (int, float)):
int_val = 1 if value else 0
elif isinstance(value, int):
float_val = float(value)
int_val = value
elif isinstance(value, float):
float_val = value
int_val = int(value)
elif isinstance(value, str):
text = value.strip()
if not text:
@@ -56,6 +61,14 @@ class NumberConvertNode(io.ComfyNode):
raise ValueError(
f"Cannot convert string to number: {value!r}"
) from None
if not math.isfinite(float_val):
raise ValueError(
f"Cannot convert non-finite value to number: {float_val}"
)
try:
int_val = int(text)
except ValueError:
int_val = int(float_val)
else:
raise TypeError(
f"Unsupported input type: {type(value).__name__}"
@@ -66,7 +79,7 @@ class NumberConvertNode(io.ComfyNode):
f"Cannot convert non-finite value to number: {float_val}"
)
return io.NodeOutput(float_val, int(float_val))
return io.NodeOutput(float_val, int_val)
class NumberConvertExtension(ComfyExtension):
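The motivation behind the separate int path, as a quick REPL-style check: doubles cannot represent every integer past 2^53, so parsing the string with int() first keeps exact values:
>>> int(float("9007199254740993"))  # 2**53 + 1 routed through float: precision lost
9007199254740992
>>> int("9007199254740993")         # direct int parse: exact
9007199254740993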

View File

@@ -30,6 +30,7 @@ class PainterNode(io.ComfyNode):
node_id="Painter",
display_name="Painter",
category="image",
has_intermediate_output=True,
inputs=[
io.Image.Input(
"image",

View File

@@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
def g(cls, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
return (g / g.sum()).to(dtype)
class Blur(io.ComfyNode):
@classmethod
@@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
@@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
image = image.to(comfy.model_management.get_torch_device())
kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
kernel = kernel.to(dtype=image.dtype)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

View File

@@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode):
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
]
),
io.DynamicCombo.Option(
@@ -25,7 +26,7 @@ class TextGenerate(io.ComfyNode):
return io.Schema(
node_id="TextGenerate",
category="textgen/",
category="textgen",
search_aliases=["LLM", "gemma"],
inputs=[
io.Clip.Input("clip"),
@@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode):
io.Image.Input("image", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
],
outputs=[
io.String.Output(display_name="generated_text"),
@@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode):
min_p = sampling_mode.get("min_p", 0.0)
seed = sampling_mode.get("seed", None)
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
generated_ids = clip.generate(
tokens,
@@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode):
top_p=top_p,
min_p=min_p,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
seed=seed
)
@@ -156,12 +160,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
class TextgenExtension(ComfyExtension):

View File

@@ -411,6 +411,19 @@ def format_value(x):
else:
return str(x)
def _is_intermediate_output(dynprompt, node_id):
class_type = dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -421,11 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
@@ -715,6 +724,9 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
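# With RAM-pressure caching, register the output cache's ram_release as a
# callback so the memory manager can drop cached node outputs under RAM pressure.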
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
try:
with torch.inference_mode():
@@ -764,9 +776,22 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
comfy.memory_management.extra_ram_release(ram_headroom)
else:
# Only executed when the while-loop ends without a break
# Send cached UI for intermediate output nodes that weren't executed
for node_id in dynamic_prompt.all_node_ids():
if node_id in executed:
continue
if not _is_intermediate_output(dynamic_prompt, node_id):
continue
cached = await self.caches.outputs.get(node_id)
if cached is not None:
display_node_id = dynamic_prompt.get_display_node_id(node_id)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
@@ -782,6 +807,7 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id)

main.py
View File

@@ -139,16 +139,7 @@ def execute_prestartup_script():
spec.loader.exec_module(module)
return True
except Exception as e:
import traceback
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
from nodes import NODE_STARTUP_ERRORS, get_module_name
node_module_name = get_module_name(os.path.dirname(script_path))
NODE_STARTUP_ERRORS[node_module_name] = {
"module_path": os.path.dirname(script_path),
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "prestartup",
}
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
@@ -284,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_ram = args.cache_ram
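# A negative --cache-ram means auto-size: target ~25% of total system RAM, clamped to the 4-32 GB range.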
if cache_ram < 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

View File

@@ -1 +1 @@
comfyui_manager==4.1b8
comfyui_manager==4.1

View File

@@ -2181,9 +2181,6 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by module name.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
def get_module_name(module_path: str) -> str:
"""
@@ -2301,13 +2298,6 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
except Exception as e:
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
module_name = get_module_name(module_path)
NODE_STARTUP_ERRORS[module_name] = {
"module_path": module_path,
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "import",
}
return False
async def init_external_custom_nodes():

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.36
comfyui-workflow-templates==0.9.39
comfyui-embedded-docs==0.4.3
torch
torchsde

View File

@@ -709,6 +709,11 @@ class PromptServer():
else:
info['output_node'] = False
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT:
info['has_intermediate_output'] = True
else:
info['has_intermediate_output'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
@@ -753,10 +758,6 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/custom_node_startup_errors")
async def get_custom_node_startup_errors(request):
return web.json_response(nodes.NODE_STARTUP_ERRORS)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.

View File

@@ -0,0 +1,81 @@
"""Tests for path_utils asset category resolution."""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path
@pytest.fixture
def fake_dirs():
"""Create temporary input, output, and temp directories."""
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
models_dir = root_path / "models" / "checkpoints"
for d in (input_dir, output_dir, temp_dir, models_dir):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(models_dir)])],
):
yield {
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"models": models_dir,
}
class TestGetAssetCategoryAndRelativePath:
def test_input_file(self, fake_dirs):
f = fake_dirs["input"] / "photo.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "input"
assert rel == "photo.png"
def test_output_file(self, fake_dirs):
f = fake_dirs["output"] / "result.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "output"
assert rel == "result.png"
def test_temp_file(self, fake_dirs):
"""Regression: temp files must be categorised, not raise ValueError."""
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "temp"
assert rel == "GLSLShader_output_00004_.png"
def test_temp_file_in_subfolder(self, fake_dirs):
sub = fake_dirs["temp"] / "sub"
sub.mkdir()
f = sub / "ComfyUI_temp_tczip_00004_.png"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "temp"
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
def test_model_file(self, fake_dirs):
f = fake_dirs["models"] / "model.safetensors"
f.touch()
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "models"
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")

View File

@@ -90,6 +90,63 @@ class TestNumberConvertExecute:
assert result[0] == 1000.0
assert result[1] == 1000
# --- Large number precision (string input) ---
def test_string_large_int_above_2_53(self):
"""Text-to-int must not lose precision for integers beyond 2^53."""
big = 2**53 + 1 # 9007199254740993
result = self._exec(str(big))
assert result[1] == big
def test_string_large_negative_int_above_2_53(self):
big = -(2**53 + 1)
result = self._exec(str(big))
assert result[1] == big
def test_string_very_large_int(self):
big = 2**63 + 42
result = self._exec(str(big))
assert result[1] == big
def test_string_large_int_float_output_is_float(self):
"""FLOAT output is still a float (may lose precision, but must be float type)."""
result = self._exec(str(2**53 + 1))
assert isinstance(result[0], float)
# --- Large number precision (int input) ---
def test_int_large_above_2_53(self):
"""Native int input must preserve its value in the INT output."""
big = 2**53 + 1
result = self._exec(big)
assert result[1] == big
def test_int_large_negative_above_2_53(self):
big = -(2**53 + 1)
result = self._exec(big)
assert result[1] == big
def test_int_very_large(self):
big = 2**100
result = self._exec(big)
assert result[1] == big
# --- String decimal / scientific notation fallback ---
def test_string_decimal_still_truncates(self):
"""Strings with decimal points fall back to int(float(...)) truncation."""
result = self._exec("3.7")
assert result[1] == 3
def test_string_negative_decimal_truncates(self):
result = self._exec("-2.9")
assert result[1] == -2
def test_string_scientific_large(self):
result = self._exec("1e18")
assert result[0] == 1e18
assert result[1] == 10**18
# --- STRING error paths ---
def test_empty_string_raises(self):

View File

@@ -24,6 +24,7 @@ def init_mime_types():
# Web types (used by server.py for static file serving)
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
mimetypes.add_type('image/webp', '.webp')
mimetypes.add_type('image/svg+xml', '.svg')
# Model and data file types (used by asset scanning / metadata extraction)
mimetypes.add_type("application/safetensors", ".safetensors")