Compare commits

..

1 Commits

Author SHA1 Message Date
Johnpaul
3278ff75fd feat: replay execution context on WS reconnect
Store current execution state (prompt_id, cached nodes, executed nodes,
outputs to execute) on the server instance during execution. On WS
reconnect, replay execution_start, execution_cached, progress_state,
and executing events so the frontend can restore progress tracking.

Refactor ProgressRegistry to expose get_serialized_state() and reuse
it in WebUIProgressHandler._send_progress_state().
2026-02-16 23:53:50 +01:00
192 changed files with 969 additions and 7083 deletions

View File

@@ -1,127 +0,0 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: false
poem: false
review_status: false
review_details: false
commit_status: true
collapse_walkthrough: true
changed_files_summary: false
sequence_diagrams: false
estimate_code_review_effort: false
assess_linked_issues: false
related_issues: false
related_prs: false
suggested_labels: false
auto_apply_labels: false
suggested_reviewers: false
auto_assign_reviewers: false
in_progress_fortune: false
enable_prompt_for_ai_agents: true
path_filters:
- "!comfy_api_nodes/apis/**"
- "!**/generated/*.pyi"
- "!.ci/**"
- "!script_examples/**"
- "!**/__pycache__/**"
- "!**/*.ipynb"
- "!**/*.png"
- "!**/*.bat"
path_instructions:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
treat it as unchanged. Contributors should not feel obligated to address
pre-existing issues outside the scope of their contribution.
- path: "comfy/**"
instructions: |
Core ML/diffusion engine. Focus on:
- Backward compatibility (breaking changes affect all custom nodes)
- Memory management and GPU resource handling
- Performance implications in hot paths
- Thread safety for concurrent execution
- path: "comfy_api_nodes/**"
instructions: |
Third-party API integration nodes. Focus on:
- No hardcoded API keys or secrets
- Proper error handling for API failures (timeouts, rate limits, auth errors)
- Correct Pydantic model usage
- Security of user data passed to external APIs
- path: "comfy_extras/**"
instructions: |
Community-contributed extra nodes. Focus on:
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
- No breaking changes to existing node interfaces
- path: "comfy_execution/**"
instructions: |
Execution engine (graph execution, caching, jobs). Focus on:
- Caching correctness
- Concurrent execution safety
- Graph validation edge cases
- path: "nodes.py"
instructions: |
Core node definitions (2500+ lines). Focus on:
- Backward compatibility of NODE_CLASS_MAPPINGS
- Consistency of INPUT_TYPES return format
- path: "alembic_db/**"
instructions: |
Database migrations. Focus on:
- Migration safety and rollback support
- Data preservation during schema changes
auto_review:
enabled: true
auto_incremental_review: true
drafts: false
ignore_title_keywords:
- "WIP"
- "DO NOT REVIEW"
- "DO NOT MERGE"
finishing_touches:
docstrings:
enabled: false
unit_tests:
enabled: false
tools:
ruff:
enabled: false
pylint:
enabled: false
flake8:
enabled: false
gitleaks:
enabled: true
shellcheck:
enabled: false
markdownlint:
enabled: false
yamllint:
enabled: false
languagetool:
enabled: false
github-checks:
enabled: true
timeout_ms: 90000
ast-grep:
essential_rules: true
chat:
auto_reply: true
knowledge_base:
opt_out: false
learnings:
scope: "auto"

2
.gitignore vendored
View File

@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
venv*/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example

View File

@@ -227,11 +227,11 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.

View File

@@ -1,107 +0,0 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
if "class_type" not in node_struct or "inputs" not in node_struct:
continue
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())

View File

@@ -53,7 +53,7 @@ class SubgraphManager:
return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r', encoding='utf-8') as f:
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry

View File

@@ -1,44 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100
in vec2 v_texCoord;
out vec4 fragColor;
const float MID_GRAY = 0.18; // 18% reflectance
// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
return pow(max(c, 0.0), vec3(2.2));
}
vec3 linearToSrgb(vec3 c) {
return pow(max(c, 0.0), vec3(1.0/2.2));
}
float mapBrightness(float b) {
return clamp(b / 100.0, -1.0, 1.0);
}
float mapContrast(float c) {
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}
void main() {
vec4 orig = texture(u_image0, v_texCoord);
float brightness = mapBrightness(u_float0);
float contrast = mapContrast(u_float1);
vec3 lin = srgbToLinear(orig.rgb);
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
// Convert back to sRGB
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
fragColor = vec4(result, orig.a);
}

View File

@@ -1,72 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Mode
uniform float u_float0; // Amount (0 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;
const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;
void main() {
vec2 uv = v_texCoord;
vec4 original = texture(u_image0, uv);
float amount = u_float0 * AMOUNT_SCALE;
if (amount < 0.000001) {
fragColor = original;
return;
}
// Aspect-corrected coordinates for circular effects
float aspect = u_resolution.x / u_resolution.y;
vec2 centered = uv - 0.5;
vec2 corrected = vec2(centered.x * aspect, centered.y);
float r = length(corrected);
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
vec2 offset = vec2(0.0);
if (u_int0 == MODE_LINEAR) {
// Horizontal shift (no aspect correction needed)
offset = vec2(amount, 0.0);
}
else if (u_int0 == MODE_RADIAL) {
// Outward from center, stronger at edges
offset = dir * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_BARREL) {
// Lens distortion simulation (r² falloff)
offset = dir * r * r * amount * BARREL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_SWIRL) {
// Perpendicular to radial (rotational aberration)
vec2 perp = vec2(-dir.y, dir.x);
offset = perp * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_DIAGONAL) {
// 45° offset (no aspect correction needed)
offset = vec2(amount, amount) * INV_SQRT2;
}
float red = texture(u_image0, uv + offset).r;
float green = original.g;
float blue = texture(u_image0, uv - offset).b;
fragColor = vec4(red, green, blue, original.a);
}

View File

@@ -1,78 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
// Scale inputs: -100/100 → -1/1
float temperature = u_float0 * INPUT_SCALE;
float tint = u_float1 * INPUT_SCALE;
float vibrance = u_float2 * INPUT_SCALE;
float saturation = u_float3 * INPUT_SCALE;
// Temperature (warm/cool): positive = warm, negative = cool
color.r += temperature * TEMP_TINT_PRIMARY;
color.b -= temperature * TEMP_TINT_PRIMARY;
// Tint (green/magenta): positive = green, negative = magenta
color.g += tint * TEMP_TINT_PRIMARY;
color.r -= tint * TEMP_TINT_SECONDARY;
color.b -= tint * TEMP_TINT_SECONDARY;
// Single clamp after temperature/tint
color = clamp(color, 0.0, 1.0);
// Vibrance with skin protection
if (vibrance != 0.0) {
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float sat = maxC - minC;
float gray = dot(color, LUMA_WEIGHTS);
if (vibrance < 0.0) {
// Desaturate: -100 → gray
color = mix(vec3(gray), color, 1.0 + vibrance);
} else {
// Boost less saturated colors more
float vibranceAmt = vibrance * (1.0 - sat);
// Branchless skin tone protection
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
float warmth = (color.r - color.b) / max(maxC, EPSILON);
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
}
}
// Saturation
if (saturation != 0.0) {
float gray = dot(color, LUMA_WEIGHTS);
float satMix = saturation < 0.0
? 1.0 + saturation // -100 → gray
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
color = mix(vec3(gray), color, satMix);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@@ -1,94 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (020, default ~5)
uniform float u_float1; // Edge threshold (0100, default ~30)
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
in vec2 v_texCoord;
out vec4 fragColor;
const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;
// Perceptual luminance
float getLuminance(vec3 rgb) {
return dot(rgb, vec3(0.299, 0.587, 0.114));
}
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
float sigmaSpatial, float sigmaColor)
{
vec4 center = texture(u_image0, uv);
vec3 centerRGB = center.rgb;
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
vec3 sumRGB = vec3(0.0);
float sumWeight = 0.0;
int step = max(u_int0, 1);
float radius2 = float(radius * radius);
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
if (dy < -radius || dy > radius) continue;
if (abs(dy) % step != 0) continue;
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
if (dx < -radius || dx > radius) continue;
if (abs(dx) % step != 0) continue;
vec2 offset = vec2(float(dx), float(dy));
float dist2 = dot(offset, offset);
if (dist2 > radius2) continue;
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
// Spatial Gaussian
float spatialWeight = exp(dist2 * invSpatial2);
// Perceptual color distance (weighted RGB)
vec3 diff = sampleRGB - centerRGB;
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
float colorWeight = exp(colorDist * invColor2);
float w = spatialWeight * colorWeight;
sumRGB += sampleRGB * w;
sumWeight += w;
}
}
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
return vec4(resultRGB, center.a); // preserve center alpha
}
void main() {
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
int radius = int(radiusF + 0.5);
if (radius == 0) {
fragColor = texture(u_image0, v_texCoord);
return;
}
// Edge threshold → color sigma
// Squared curve for better low-end control
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
t *= t;
float sigmaColor = mix(0.01, 0.5, t);
// Spatial sigma tied to radius
float sigmaSpatial = max(radiusF * 0.75, 0.5);
fragColor = bilateralFilter(
v_texCoord,
texelSize,
radius,
sigmaSpatial,
sigmaColor
);
}

View File

@@ -1,124 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0 1.0] typical: 0.20.8
uniform float u_float1; // grain size [0.3 3.0] lower = finer grain
uniform float u_float2; // color amount [0.0 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// High-quality integer hash (pcg-like)
uint pcg(uint v) {
uint state = v * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
// 2D -> 1D hash input
uint hash2d(uvec2 p) {
return pcg(p.x + pcg(p.y));
}
// Hash to float [0, 1]
float hashf(uvec2 p) {
return float(hash2d(p)) / float(0xffffffffu);
}
// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}
// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
return (sum - 2.0) * 0.7; // Centered, scaled
}
float toGaussian(uvec2 p, uint offset) {
float sum = hashf(p, offset) + hashf(p, offset + 1u)
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
return (sum - 2.0) * 0.7;
}
// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
vec2 i = floor(p);
vec2 f = fract(p);
// Quintic interpolation (less banding than cubic)
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui);
float b = toGaussian(ui + uvec2(1u, 0u));
float c = toGaussian(ui + uvec2(0u, 1u));
float d = toGaussian(ui + uvec2(1u, 1u));
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
float smoothNoise(vec2 p, uint offset) {
vec2 i = floor(p);
vec2 f = fract(p);
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui, offset);
float b = toGaussian(ui + uvec2(1u, 0u), offset);
float c = toGaussian(ui + uvec2(0u, 1u), offset);
float d = toGaussian(ui + uvec2(1u, 1u), offset);
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Luminance (Rec.709)
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
// Grain UV (resolution-independent)
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
uvec2 grainPixel = uvec2(grainUV);
float g;
vec3 grainRGB;
if (u_int0 == 1) {
// Grainy mode: pure hash noise (no interpolation = no banding)
g = toGaussian(grainPixel);
grainRGB = vec3(
toGaussian(grainPixel, 100u),
toGaussian(grainPixel, 200u),
toGaussian(grainPixel, 300u)
);
} else {
// Smooth mode: interpolated with quintic curve
g = smoothNoise(grainUV);
grainRGB = vec3(
smoothNoise(grainUV, 100u),
smoothNoise(grainUV, 200u),
smoothNoise(grainUV, 300u)
);
}
// Luminance weighting (less grain in highlights)
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
// Strength
float strength = u_float0 * 0.15;
// Color vs monochrome grain
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
color.rgb += grainColor * strength * lumWeight;
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}

View File

@@ -1,133 +0,0 @@
#version 300 es
precision mediump float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold
in vec2 v_texCoord;
out vec4 fragColor;
const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;
const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
float hash(vec2 p) {
p = fract(p * vec2(123.34, 456.21));
p += dot(p, p + 45.32);
return fract(p.x * p.y);
}
vec3 hexToRgb(int h) {
return vec3(
float((h >> 16) & 255),
float((h >> 8) & 255),
float(h & 255)
) * (1.0 / 255.0);
}
vec3 blend(vec3 base, vec3 glow, int mode) {
if (mode == BLEND_SCREEN) {
return 1.0 - (1.0 - base) * (1.0 - glow);
}
if (mode == BLEND_SOFT) {
return mix(
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
step(0.5, glow)
);
}
if (mode == BLEND_OVERLAY) {
return mix(
2.0 * base * glow,
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
step(0.5, base)
);
}
if (mode == BLEND_LIGHTEN) {
return max(base, glow);
}
return base + glow;
}
void main() {
vec4 original = texture(u_image0, v_texCoord);
float intensity = u_float0 * 0.05;
float radius = u_float1 * u_float1 * 0.012;
if (intensity < 0.001 || radius < 0.1) {
fragColor = original;
return;
}
float threshold = 1.0 - u_float2 * 0.01;
float t0 = threshold - 0.15;
float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution;
float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
int samples = int(float(MAX_SAMPLES) * sampleScale);
float noise = hash(gl_FragCoord.xy);
float angleOffset = noise * GOLDEN_ANGLE;
float radiusJitter = 0.85 + noise * 0.3;
float ca = cos(GOLDEN_ANGLE);
float sa = sin(GOLDEN_ANGLE);
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
vec3 glow = vec3(0.0);
float totalWeight = 0.0;
// Center tap
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
glow += original.rgb * centerMask * 2.0;
totalWeight += 2.0;
for (int i = 1; i < MAX_SAMPLES; i++) {
if (i >= samples) break;
float fi = float(i);
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
vec2 offset = dir * dist * texelSize;
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
float mask = smoothstep(t0, t1, dot(c, LUMA));
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
w = max(w, 0.0);
w *= w;
glow += c * mask * w;
totalWeight += w;
dir = vec2(
dir.x * ca - dir.y * sa,
dir.x * sa + dir.y * ca
);
}
glow *= intensity / max(totalWeight, 0.001);
if (u_int1 > 0) {
glow *= hexToRgb(u_int1);
}
vec3 result = blend(original.rgb, glow, u_int0);
result += (noise - 0.5) * (1.0 / 255.0);
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}

View File

@@ -1,222 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
in vec2 v_texCoord;
out vec4 fragColor;
// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;
// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;
const float EPSILON = 0.0001;
//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================
vec3 rgb2hsl(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = 0.0;
float l = (maxC + minC) * 0.5;
if (delta > EPSILON) {
s = l < 0.5
? delta / (maxC + minC)
: delta / (2.0 - maxC - minC);
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
t = fract(t);
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
if (t < 0.5) return q;
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
if (hsl.y < EPSILON) return vec3(hsl.z);
float q = hsl.z < 0.5
? hsl.z * (1.0 + hsl.y)
: hsl.z + hsl.y - hsl.z * hsl.y;
float p = 2.0 * hsl.z - q;
return vec3(
hue2rgb(p, q, hsl.x + 1.0/3.0),
hue2rgb(p, q, hsl.x),
hue2rgb(p, q, hsl.x - 1.0/3.0)
);
}
vec3 rgb2hsb(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
float b = maxC;
if (delta > EPSILON) {
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, b);
}
vec3 hsb2rgb(vec3 hsb) {
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}
//=============================================================================
// Color Range Weight Calculation
//=============================================================================
float hueDistance(float a, float b) {
float d = abs(a - b);
return min(d, 1.0 - d);
}
float getHueWeight(float hue, float center, float overlap) {
float baseWidth = 1.0 / 6.0;
float feather = baseWidth * overlap;
float d = hueDistance(hue, center);
float inner = baseWidth * 0.5;
float outer = inner + feather;
return 1.0 - smoothstep(inner, outer, d);
}
float getModeWeight(float hue, int mode, float overlap) {
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
if (mode == MODE_RED) {
return max(
getHueWeight(hue, 0.0, overlap),
getHueWeight(hue, 1.0, overlap)
);
}
float center = float(mode - 1) / 6.0;
return getHueWeight(hue, center, overlap);
}
//=============================================================================
// Adjustment Functions
//=============================================================================
float adjustLightness(float l, float amount) {
return amount > 0.0
? l + (1.0 - l) * amount
: l + l * amount;
}
float adjustBrightness(float b, float amount) {
return clamp(b + amount, 0.0, 1.0);
}
float adjustSaturation(float s, float amount) {
return amount > 0.0
? s + (1.0 - s) * amount
: s + s * amount;
}
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
float l = adjustLightness(lum, light);
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
return hsl2rgb(hsl);
}
//=============================================================================
// Main
//=============================================================================
void main() {
vec4 original = texture(u_image0, v_texCoord);
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
vec3 result;
if (u_int0 == MODE_COLORIZE) {
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
fragColor = vec4(result, original.a);
return;
}
vec3 hsx = (u_int1 == COLORSPACE_HSL)
? rgb2hsl(original.rgb)
: rgb2hsb(original.rgb);
float weight = getModeWeight(hsx.x, u_int0, overlap);
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
weight = 0.0;
}
if (weight > EPSILON) {
float h = fract(hsx.x + hueShift * weight);
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
float v = (u_int1 == COLORSPACE_HSL)
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
vec3 adjusted = vec3(h, s, v);
result = (u_int1 == COLORSPACE_HSL)
? hsl2rgb(adjusted)
: hsb2rgb(adjusted);
} else {
result = original.rgb;
}
fragColor = vec4(result, original.a);
}

View File

@@ -1,111 +0,0 @@
#version 300 es
#pragma passes 2
precision highp float;
// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;
// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
void main() {
vec2 texelSize = 1.0 / u_resolution;
float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable
if (u_int0 == BLUR_RADIAL) {
// Only execute on first pass
if (u_pass > 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec2 center = vec2(0.5);
vec2 dir = v_texCoord - center;
float dist = length(dir);
if (dist < 1e-4) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec4 sum = vec4(0.0);
float totalWeight = 0.0;
float angleStep = radius * RADIAL_STRENGTH;
dir /= dist;
float cosStep = cos(angleStep);
float sinStep = sin(angleStep);
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
vec2 rotDir = vec2(
dir.x * cos(negAngle) - dir.y * sin(negAngle),
dir.x * sin(negAngle) + dir.y * cos(negAngle)
);
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
vec2 uv = center + rotDir * dist;
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
sum += texture(u_image0, uv) * w;
totalWeight += w;
rotDir = vec2(
rotDir.x * cosStep - rotDir.y * sinStep,
rotDir.x * sinStep + rotDir.y * cosStep
);
}
fragColor0 = sum / max(totalWeight, 0.001);
return;
}
// Separable Gaussian / Box blur
int samples = int(ceil(radius));
if (samples == 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
// Direction: pass 0 = horizontal, pass 1 = vertical
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
vec4 color = vec4(0.0);
float totalWeight = 0.0;
float sigma = radius / 2.0;
for (int i = -samples; i <= samples; i++) {
vec2 offset = dir * float(i) * texelSize;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float weight;
if (u_int0 == BLUR_GAUSSIAN) {
weight = gaussian(float(i), sigma);
} else {
// BLUR_BOX
weight = 1.0;
}
color += sample_color * weight;
totalWeight += weight;
}
fragColor0 = color / totalWeight;
}

View File

@@ -1,19 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Output each channel as grayscale to separate render targets
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}

View File

@@ -1,71 +0,0 @@
#version 300 es
precision highp float;
// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255
uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
result = pow(result, vec3(1.0 / gamma));
result = mix(vec3(outBlack), vec3(outWhite), result);
return result;
}
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
result = pow(result, 1.0 / gamma);
result = mix(outBlack, outWhite, result);
return result;
}
void main() {
vec4 texColor = texture(u_image0, v_texCoord);
vec3 color = texColor.rgb;
float inBlack = u_float0 / 255.0;
float inWhite = u_float1 / 255.0;
float gamma = u_float2;
float outBlack = u_float3 / 255.0;
float outWhite = u_float4 / 255.0;
vec3 result;
if (u_int0 == 0) {
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 1) {
result = color;
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 2) {
result = color;
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 3) {
result = color;
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
}
else {
result = color;
}
fragColor = vec4(result, texColor.a);
}

View File

@@ -1,28 +0,0 @@
# GLSL Shader Sources
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
## File Naming Convention
`{Blueprint_Name}_{node_id}.frag`
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph
## Usage
```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract
# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```
## Workflow
1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs

View File

@@ -1,28 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
vec2 texel = 1.0 / u_resolution;
// Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord);
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
// Edge enhancement (Laplacian)
vec4 edges = center * 4.0 - top - bottom - left - right;
// Add edges back scaled by strength
vec4 sharpened = center + edges * u_float0;
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}

View File

@@ -1,61 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
float getLuminance(vec3 color) {
return dot(color, vec3(0.2126, 0.7152, 0.0722));
}
void main() {
vec2 texel = 1.0 / u_resolution;
float radius = max(u_float1, 0.5);
float amount = u_float0;
float threshold = u_float2;
vec4 original = texture(u_image0, v_texCoord);
// Gaussian blur for the "unsharp" mask
int samples = int(ceil(radius));
float sigma = radius / 2.0;
vec4 blurred = vec4(0.0);
float totalWeight = 0.0;
for (int x = -samples; x <= samples; x++) {
for (int y = -samples; y <= samples; y++) {
vec2 offset = vec2(float(x), float(y)) * texel;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float dist = length(vec2(float(x), float(y)));
float weight = gaussian(dist, sigma);
blurred += sample_color * weight;
totalWeight += weight;
}
}
blurred /= totalWeight;
// Unsharp mask = original - blurred
vec3 mask = original.rgb - blurred.rgb;
// Luminance-based threshold with smooth falloff
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
mask *= thresholdScale;
// Sharpen: original + mask * amount
vec3 sharpened = original.rgb + mask * amount;
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}

View File

@@ -1,159 +0,0 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater
Syncs GLSL shader files between this folder and blueprint JSON files.
File naming convention:
{Blueprint Name}_{node_id}.frag
Usage:
python update_blueprints.py extract # Extract shaders from JSONs to here
python update_blueprints.py patch # Patch shaders back into JSONs
python update_blueprints.py # Same as patch (default)
"""
import json
import logging
import sys
import re
from pathlib import Path
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent
def get_blueprint_files():
"""Get all blueprint JSON files."""
return sorted(BLUEPRINTS_DIR.glob("*.json"))
def sanitize_filename(name):
"""Convert blueprint name to safe filename."""
return re.sub(r'[^\w\-]', '_', name)
def extract_shaders():
"""Extract all shaders from blueprint JSONs to this folder."""
extracted = 0
for json_path in get_blueprint_files():
blueprint_name = json_path.stem
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.warning("Skipping %s: %s", json_path.name, e)
continue
# Find GLSLShader nodes in subgraphs
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('type') == 'GLSLShader':
node_id = node.get('id')
widgets = node.get('widgets_values', [])
# Find shader code (first string that looks like GLSL)
for widget in widgets:
if isinstance(widget, str) and widget.startswith('#version'):
safe_name = sanitize_filename(blueprint_name)
frag_name = f"{safe_name}_{node_id}.frag"
frag_path = GLSL_DIR / frag_name
with open(frag_path, 'w') as f:
f.write(widget)
logger.info(" Extracted: %s", frag_name)
extracted += 1
break
logger.info("\nExtracted %d shader(s)", extracted)
def patch_shaders():
"""Patch shaders from this folder back into blueprint JSONs."""
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
shader_updates = {}
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
# Parse filename: {blueprint_name}_{node_id}.frag
parts = frag_path.stem.rsplit('_', 1)
if len(parts) != 2:
logger.warning("Skipping %s: invalid filename format", frag_path.name)
continue
blueprint_name, node_id_str = parts
try:
node_id = int(node_id_str)
except ValueError:
logger.warning("Skipping %s: invalid node_id", frag_path.name)
continue
with open(frag_path, 'r') as f:
shader_code = f.read()
if blueprint_name not in shader_updates:
shader_updates[blueprint_name] = []
shader_updates[blueprint_name].append((node_id, shader_code))
# Apply updates to JSON files
patched = 0
for json_path in get_blueprint_files():
blueprint_name = sanitize_filename(json_path.stem)
if blueprint_name not in shader_updates:
continue
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.error("Error reading %s: %s", json_path.name, e)
continue
modified = False
for node_id, shader_code in shader_updates[blueprint_name]:
# Find the node and update
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
widgets = node.get('widgets_values', [])
if len(widgets) > 0 and widgets[0] != shader_code:
widgets[0] = shader_code
modified = True
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
patched += 1
if modified:
with open(json_path, 'w') as f:
json.dump(data, f)
if patched == 0:
logger.info("No changes to apply.")
else:
logger.info("\nPatched %d shader(s)", patched)
def main():
if len(sys.argv) < 2:
command = "patch"
else:
command = sys.argv[1].lower()
if command == "extract":
logger.info("Extracting shaders from blueprints...")
extract_shaders()
elif command in ("patch", "update", "apply"):
logger.info("Patching shaders into blueprints...")
patch_shaders()
else:
logger.info(__doc__)
sys.exit(1)
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}

View File

@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}

View File

@@ -0,0 +1,13 @@
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)

View File

@@ -176,8 +176,6 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):

View File

@@ -297,30 +297,6 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class QwenFunControlNet(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
# Fun checkpoints are more sensitive to high strengths in the generic
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
# unchanged while >1 grows more gently.
original_strength = self.strength
self.strength = math.sqrt(max(self.strength, 0.0))
try:
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
finally:
self.strength = original_strength
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -584,7 +560,6 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -630,53 +605,6 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_fun(sd, model_options={}):
load_device = comfy.model_management.get_torch_device()
weight_dtype = comfy.utils.weight_dtype(sd)
unet_dtype = model_options.get("dtype", weight_dtype)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
in_features = sd["control_img_in.weight"].shape[1]
inner_dim = sd["control_img_in.weight"].shape[0]
block_weight = sd["control_blocks.0.attn.to_q.weight"]
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
control_in_features=in_features,
inner_dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
operations=operations,
device=comfy.model_management.unet_offload_device(),
dtype=unet_dtype,
)
model = controlnet_load_state_dict(model, sd)
latent_format = comfy.latent_formats.Wan21()
control = QwenFunControlNet(
model,
compression_ratio=1,
latent_format=latent_format,
# Fun checkpoints already expect their own 33-channel context handling.
# Enabling generic concat_mask injects an extra mask channel at apply-time
# and breaks the intended fallback packing path.
concat_mask=False,
load_device=load_device,
manual_cast_dtype=manual_cast_dtype,
extra_conds=[],
)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -754,8 +682,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)

View File

@@ -3,6 +3,7 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
@@ -28,7 +29,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property

View File

@@ -152,7 +152,6 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -229,7 +228,6 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:

View File

@@ -4,6 +4,8 @@ from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
@@ -143,7 +145,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.mlp_ratio = mlp_ratio
@@ -176,7 +178,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -188,7 +190,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,

View File

@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,12 +87,20 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -161,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -189,6 +197,8 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -196,9 +206,6 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -217,23 +224,32 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@@ -312,9 +328,6 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -324,12 +337,6 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@@ -16,6 +16,7 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
@dataclass
@@ -80,7 +81,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
else:
self.txt_norm = None
@@ -142,7 +143,6 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -232,7 +232,6 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -241,6 +241,7 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -304,7 +305,6 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, :txt.shape[1]] = txt_mask
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
@@ -413,11 +413,10 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((txt, img), 1)
img = torch.cat((img, txt), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
@@ -436,9 +435,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img[:, : img_len] += add
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

View File

@@ -9,7 +9,6 @@ from comfy.ldm.lightricks.model import (
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
class CompressedTimestep:
@@ -451,29 +450,6 @@ class LTXAVModel(LTXVModel):
operations=self.operations,
)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(

View File

@@ -157,9 +157,11 @@ class Embeddings1DConnector(nn.Module):
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.empty(
torch.rand(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
)
* 2.0
- 1.0
)
def get_fractional_positions(self, indices_grid):
@@ -232,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
@@ -245,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
def forward(
self,
@@ -286,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
freqs_cis = self.precompute_freqs_cis(indices_grid)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):

View File

@@ -102,7 +102,19 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):

View File

@@ -2,196 +2,6 @@ import torch
import math
from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock
class QwenImageFunControlBlock(QwenImageTransformerBlock):
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
super().__init__(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.has_before_proj = has_before_proj
if has_before_proj:
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
class QwenImageFunControlNetModel(torch.nn.Module):
def __init__(
self,
control_in_features=132,
inner_dim=3072,
num_attention_heads=24,
attention_head_dim=128,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.main_model_double = main_model_double
self.injection_layers = tuple(injection_layers)
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
# to the reference Gen2/VideoX implementation around strength=1.
self.hint_scale = 1.0
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
self.control_blocks = torch.nn.ModuleList([])
for i in range(num_control_blocks):
self.control_blocks.append(
QwenImageFunControlBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
has_before_proj=(i == 0),
dtype=dtype,
device=device,
operations=operations,
)
)
def _process_hint_tokens(self, hint):
if hint is None:
return None
if hint.ndim == 4:
hint = hint.unsqueeze(2)
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
# Default behavior (no inpaint input in stock Apply ControlNet) should use
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
expected_c = self.control_img_in.weight.shape[1] // 4
if hint.shape[1] == 16 and expected_c == 33:
zeros_mask = torch.zeros_like(hint[:, :1])
zeros_inpaint = torch.zeros_like(hint)
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
bs, c, t, h, w = hint.shape
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(
orig_shape[0],
orig_shape[1],
orig_shape[-3],
orig_shape[-2] // 2,
2,
orig_shape[-1] // 2,
2,
)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(
bs,
t * ((h + 1) // 2) * ((w + 1) // 2),
c * 4,
)
expected_in = self.control_img_in.weight.shape[1]
cur_in = hidden_states.shape[-1]
if cur_in < expected_in:
pad = torch.zeros(
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
hidden_states = torch.cat([hidden_states, pad], dim=-1)
elif cur_in > expected_in:
hidden_states = hidden_states[:, :, :expected_in]
return hidden_states
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
hint=None,
transformer_options={},
base_model=None,
**kwargs,
):
if base_model is None:
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
encoder_hidden_states_mask = attention_mask
# Keep attention mask disabled inside Fun control blocks to mirror
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
encoder_hidden_states_mask = None
hidden_states, img_ids, _ = base_model.process_img(x)
hint_tokens = self._process_hint_tokens(hint)
if hint_tokens is None:
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
if hint_tokens.shape[1] != hidden_states.shape[1]:
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
hint_tokens = hint_tokens[:, :max_tokens]
hidden_states = hidden_states[:, :max_tokens]
img_ids = img_ids[:, :max_tokens]
txt_start = round(
max(
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
)
)
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
hidden_states = base_model.img_in(hidden_states)
encoder_hidden_states = base_model.txt_norm(context)
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
base_model.time_text_embed(timesteps, hidden_states)
if guidance is None
else base_model.time_text_embed(timesteps, guidance, hidden_states)
)
c = self.control_img_in(hint_tokens)
for i, block in enumerate(self.control_blocks):
if i == 0:
c_in = block.before_proj(c) + hidden_states
all_c = []
else:
all_c = list(torch.unbind(c, dim=0))
c_in = all_c.pop(-1)
encoder_hidden_states, c_out = block(
hidden_states=c_in,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
c_skip = block.after_proj(c_out) * self.hint_scale
all_c += [c_skip, c_out]
c = torch.stack(all_c, dim=0)
hints = torch.unbind(c, dim=0)[:-1]
controlnet_block_samples = [None] * self.main_model_double
for local_idx, base_idx in enumerate(self.injection_layers):
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
controlnet_block_samples[base_idx] = hints[local_idx]
return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):

View File

@@ -374,31 +374,6 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]

View File

@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

View File

@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_enabled = False
aimdo_allocator = None

View File

@@ -178,7 +178,10 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype_inference()
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
@@ -215,13 +218,6 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs):
return None
@@ -376,7 +372,9 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype_inference()
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -988,14 +986,10 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1171,7 +1165,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

View File

@@ -19,12 +19,6 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -192,7 +186,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -247,8 +241,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -256,8 +249,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -267,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -276,7 +268,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]

View File

@@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
return torch_dev
else:
return cpu_dev
@@ -1121,6 +1121,7 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

View File

@@ -271,7 +271,6 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -308,15 +307,8 @@ class ModelPatcher:
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self, disable_dynamic=False):
class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -370,8 +362,6 @@ class ModelPatcher:
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
@@ -416,16 +406,13 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.disable_model_cfg1_optimization()
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@@ -692,19 +679,18 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, prio_comfy_cast_weights=False):
loading = []
for n, m in self.model.named_modules():
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True):
if name not in params:
default = True # default random weights in non leaf modules
skip = True # skip random weights in non leaf modules
break
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1509,7 +1495,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
for x in loading:
@@ -1527,10 +1513,8 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return (False, 0)
return 0
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1544,13 +1528,7 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
return comfy.memory_management.vram_aligned_size(geometry)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1558,19 +1536,13 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
@@ -1588,8 +1560,6 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
@@ -1609,7 +1579,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(prio_comfy_cast_weights=True)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1630,13 +1600,6 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@@ -19,8 +19,9 @@
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
@@ -79,7 +80,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
@@ -170,10 +171,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(want_requant and len(fns) == 0 or update_weight)):
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if want_requant and len(fns) == 0:
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
@@ -194,7 +195,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -212,7 +213,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -296,7 +297,7 @@ class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
super().__init__(in_features, out_features, bias, device, dtype)
return
@@ -317,7 +318,7 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
@@ -462,7 +463,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
@@ -474,7 +475,8 @@ class disable_weight_init:
weight = None
bias = None
offload_stream = None
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -827,10 +829,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
@@ -854,8 +852,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -885,7 +883,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D
if reshaped_3d:

View File

@@ -1,10 +1,57 @@
import torch
import comfy.model_management
import numbers
import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@@ -423,17 +423,6 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1193,7 +1182,6 @@ class TEModel(Enum):
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
def detect_te_model(sd):
@@ -1222,10 +1210,7 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if 'vision_model.embeddings.patch_embedding.weight' in sd:
return TEModel.GEMMA_3_4B_VISION
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1285,8 +1270,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
if "lm_head.weight" in clip_data[i]:
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {}
clip_target = EmptyClass()
@@ -1352,14 +1335,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B_VISION:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_12B:
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
@@ -1530,24 +1505,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return model
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
clip = None
clipvision = None
vae = None
@@ -1596,8 +1561,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
@@ -1648,7 +1612,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1732,8 +1696,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@@ -1742,13 +1705,12 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
logging.info("left over keys in diffusion model: {}".format(left_over))
return model_patcher
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
def load_unet(unet_path, dtype=None):

View File

@@ -171,9 +171,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
cmp_token = pad_token
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
@@ -187,21 +186,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
token = int(y)
if index == 0 and token == pad_token:
left_pad = True
if eos or (left_pad and token == pad_token):
if eos:
attention_mask.append(0)
else:
attention_mask.append(1)
left_pad = False
token = int(y)
tokens_temp += [token]
if not eos and token == cmp_token and not left_pad:
if not eos and token == cmp_token:
if end_token is None:
attention_mask[-1] = 0
eos = True
@@ -308,15 +301,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def parse_parentheses(string):
result = []
current_item = ""
@@ -573,8 +557,6 @@ class SDTokenizer:
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
min_length = kwargs.get("min_length", min_length)
text = escape_important(text)
if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)]
@@ -674,9 +656,6 @@ class SDTokenizer:
def state_dict(self):
return {}
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
@@ -700,9 +679,6 @@ class SD1Tokenizer:
def state_dict(self):
return getattr(self, self.clip).state_dict()
def decode(self, token_ids, skip_special_tokens=True):
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -739,6 +715,3 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

View File

@@ -710,15 +710,6 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -907,13 +898,11 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1275,15 +1264,6 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1361,14 +1341,6 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)

View File

@@ -10,6 +10,7 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
@@ -35,6 +36,9 @@ def sample_manual_loop_no_classes(
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
output_audio_codes = []
past_key_values = []
@@ -131,11 +135,13 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -33,8 +33,6 @@ class AnimaTokenizer:
def state_dict(self):
return {}
def decode(self, token_ids, **kwargs):
return self.qwen3_06b.decode(token_ids, **kwargs)
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):

View File

@@ -3,8 +3,6 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any, Tuple
import math
from tqdm import tqdm
import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -105,7 +103,6 @@ class Qwen3_06BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_06B_ACE15_Config:
@@ -129,7 +126,6 @@ class Qwen3_06B_ACE15_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_2B_ACE15_lm_Config:
@@ -153,7 +149,6 @@ class Qwen3_2B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_4B_ACE15_lm_Config:
@@ -177,7 +172,6 @@ class Qwen3_4B_ACE15_lm_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_4BConfig:
@@ -201,7 +195,6 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Qwen3_8BConfig:
@@ -225,7 +218,6 @@ class Qwen3_8BConfig:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
class Ovis25_2BConfig:
@@ -296,7 +288,6 @@ class Gemma2_2B_Config:
rope_scale = None
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1]
@dataclass
class Gemma3_4B_Config:
@@ -321,14 +312,6 @@ class Gemma3_4B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
stop_tokens = [1, 106]
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
@dataclass
class Gemma3_12B_Config:
@@ -353,9 +336,8 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = GEMMA3_VISION_CONFIG
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
stop_tokens = [1, 106]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -373,6 +355,13 @@ class RMSNorm(nn.Module):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -401,30 +390,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
nsin = freqs_cis[2]
q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype)
@@ -459,10 +438,8 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
sliding_window: Optional[int] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -497,11 +474,6 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -584,12 +556,10 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
sliding_window = None
if self.transformer_type == 'gemma3':
if self.sliding_attention:
sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
@@ -608,7 +578,6 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
sliding_window=sliding_window,
)
x = self.post_attention_layernorm(x)
@@ -793,107 +762,6 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class BaseGenerate:
def logits(self, x):
input = x[:, -1:]
if hasattr(self.model, "lm_head"):
module = self.model.lm_head
else:
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32
embeds = embeds.to(execution_dtype)
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
generated_token_ids = []
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
pbar.update(1)
if token_id in stop_tokens:
break
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
# Sampling mode
if repetition_penalty != 1.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
min_threshold = min_p * top_probs
indices_to_remove = probs_before_filter < min_threshold
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = torch.finfo(logits.dtype).min
probs = torch.nn.functional.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=generator)
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
@@ -937,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -964,7 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -982,7 +850,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
@@ -1000,7 +868,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
@@ -1010,9 +878,6 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
# todo: should this be tied or not?
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -1046,7 +911,7 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma2_2B_Config(**config_dict)
@@ -1055,7 +920,7 @@ class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
@@ -1064,25 +929,7 @@ class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Vision_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
class Gemma3_12B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)

View File

@@ -3,9 +3,9 @@ import os
from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
import math
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -22,86 +22,46 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_Tokenizer():
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(896 * 896)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
images = [s[:, :, :, :3]]
if text.startswith('<start_of_turn>'):
skip_template = True
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
if len(images) > 0:
embed_count = 0
for r in text_tokens:
for i, token in enumerate(r):
if token[0] == 262144 and embed_count < len(images):
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
embed_count += 1
return text_tokens
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
self.dtypes = set()
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(text, return_word_ids)
embed_count = 0
for k in text_tokens:
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.compat_mode = False
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
@@ -109,11 +69,6 @@ class LTXAVTEModel(torch.nn.Module):
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
operations = self.gemma3_12b.operations
dtype = self.text_embedding_projection.weight.dtype
device = self.text_embedding_projection.weight.device
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
@@ -129,7 +84,6 @@ class LTXAVTEModel(torch.nn.Module):
device=device,
operations=operations,
)
self.compat_mode = True
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
@@ -143,7 +97,6 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -152,45 +105,30 @@ class LTXAVTEModel(torch.nn.Module):
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing_all = []
unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
if ww is not None:
if ww.shape[0] == 3840:
self.enable_compat_mode()
sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
return (missing_all, unexpected_all)
def memory_estimation_function(self, token_weight_pairs, device=None):
@@ -200,7 +138,6 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
@@ -213,14 +150,3 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Gemma3_12BModel_

View File

@@ -1,23 +1,23 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
from comfy.text_encoders.lt import Gemma3_Tokenizer
import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<end_of_turn>": 107}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -40,20 +40,6 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -64,8 +50,6 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
elif model_type == "gemma3_4b_vision":
model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@@ -6,10 +6,9 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
self.add_bos = add_bos
self.add_eos = add_eos
self.special_tokens = special_tokens
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
@@ -28,32 +27,8 @@ class SPieceTokenizer:
return out
def __call__(self, string):
if self.special_tokens is not None:
import re
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
if special_tokens_pattern and re.search(special_tokens_pattern, string):
parts = re.split(f'({special_tokens_pattern})', string)
result = []
for part in parts:
if not part:
continue
if part in self.special_tokens:
result.append(self.special_tokens[part])
else:
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
result.extend(encoded)
return {"input_ids": result}
out = self.tokenizer.encode(string)
return {"input_ids": out}
def decode(self, token_ids, skip_special_tokens=False):
if skip_special_tokens and self.special_tokens:
special_token_ids = set(self.special_tokens.values())
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
return self.tokenizer.decode(token_ids)
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

View File

@@ -20,7 +20,7 @@
import torch
import math
import struct
import comfy.memory_management
import comfy.checkpoint_pickle
import safetensors.torch
import numpy as np
from PIL import Image
@@ -29,7 +29,7 @@ import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args
from comfy.cli_args import args, enables_dynamic_vram
import json
import time
import mmap
@@ -38,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
if True: # ckpt/pt file whitelist for safe loading of old sd files
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
return None
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
def encode(*args, **kwargs): # no longer necessary on newer torch
return None
encode.__module__ = "_codecs"
from _codecs import encode
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
# Current as of safetensors 0.7.0
_TYPES = {
@@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
if comfy.memory_management.aimdo_enabled:
if enables_dynamic_vram():
sd, metadata = load_safetensors(ckpt)
if not return_metadata:
metadata = None
@@ -140,8 +140,11 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -672,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k in block_map:
@@ -698,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.weight",
"attn.norm_k.weight": "norm.key_norm.weight",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
@@ -1154,7 +1157,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if not comfy.memory_management.aimdo_enabled:
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
@@ -1418,11 +1421,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@@ -49,12 +49,6 @@ class WeightAdapterBase:
"""
raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight(
self,
weight,

View File

@@ -214,13 +214,6 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight(
self,
weight,

View File

@@ -14,7 +14,6 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}

View File

@@ -21,17 +21,6 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -84,6 +73,8 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""

View File

@@ -444,7 +444,7 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))

View File

@@ -73,15 +73,8 @@ class RemoteOptions:
class NumberDisplay(str, Enum):
number = "number"
slider = "slider"
gradient_slider = "gradientslider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -270,7 +263,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -297,15 +290,13 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
self.max = max
self.step = step
self.round = round
self.display_mode = display_mode
self.gradient_stops = gradient_stops
self.default: float
def as_dict(self):
@@ -315,7 +306,6 @@ class Float(ComfyTypeIO):
"step": self.step,
"round": self.round,
"display": self.display_mode,
"gradient_stops": self.gradient_stops,
})
@comfytype(io_type="STRING")
@@ -355,7 +345,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool | ControlAfterGenerate=None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -399,7 +389,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -1213,30 +1203,6 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -1343,7 +1309,6 @@ class NodeInfoV1:
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
essentials_category: str=None
@dataclass
@@ -1465,8 +1430,6 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
def validate(self):
'''Validate the schema:
@@ -1573,7 +1536,6 @@ class Schema:
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
essentials_category=self.essentials_category,
)
return info
@@ -2068,74 +2030,11 @@ class _UIOutput(ABC):
...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
"ControlAfterGenerate",
"comfytype",
"Custom",
@@ -2222,6 +2121,4 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"NodeReplace",
]

View File

@@ -45,55 +45,17 @@ class BriaEditImageRequest(BaseModel):
)
class BriaRemoveBackgroundRequest(BaseModel):
image: str = Field(...)
sync: bool = Field(False)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on input image moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
seed: int = Field(...)
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
class BriaRemoveBackgroundResult(BaseModel):
image_url: str = Field(...)
class BriaRemoveBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveBackgroundResult | None = Field(None)
class BriaImageEditResult(BaseModel):
class BriaResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
class BriaImageEditResponse(BaseModel):
class BriaResponse(BaseModel):
status: str = Field(...)
result: BriaImageEditResult | None = Field(None)
class BriaRemoveVideoBackgroundRequest(BaseModel):
video: str = Field(...)
background_color: str = Field(default="transparent", description="Background color for the output video.")
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaRemoveVideoBackgroundResult(BaseModel):
video_url: str = Field(...)
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)
result: BriaResult | None = Field(None)

View File

@@ -27,7 +27,6 @@ class Seedream4TaskCreationRequest(BaseModel):
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(False)
output_format: str | None = None
class ImageTaskCreationResponse(BaseModel):
@@ -107,7 +106,6 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("3072x3072 (1:1)", 3072, 3072),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]

View File

@@ -1,88 +0,0 @@
from pydantic import BaseModel, Field
class SpeechToTextRequest(BaseModel):
model_id: str = Field(...)
cloud_storage_url: str = Field(...)
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
tag_audio_events: bool | None = Field(None, description="Annotate sounds like (laughter) in transcript")
num_speakers: int | None = Field(None, description="Max speakers predicted")
timestamps_granularity: str = Field(default="word", description="Timing precision: none, word, or character")
diarize: bool | None = Field(None, description="Annotate which speaker is talking")
diarization_threshold: float | None = Field(None, description="Speaker separation sensitivity")
temperature: float | None = Field(None, description="Randomness control")
seed: int = Field(..., description="Seed for deterministic sampling")
class SpeechToTextWord(BaseModel):
text: str = Field(..., description="The word text")
type: str = Field(default="word", description="Type of text element (word, spacing, etc.)")
start: float | None = Field(None, description="Start time in seconds (when timestamps enabled)")
end: float | None = Field(None, description="End time in seconds (when timestamps enabled)")
speaker_id: str | None = Field(None, description="Speaker identifier when diarization is enabled")
logprob: float | None = Field(None, description="Log probability of the word")
class SpeechToTextResponse(BaseModel):
language_code: str = Field(..., description="Detected or specified language code")
language_probability: float | None = Field(None, description="Confidence of language detection")
text: str = Field(..., description="Full transcript text")
words: list[SpeechToTextWord] | None = Field(None, description="Word-level timing information")
class TextToSpeechVoiceSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability")
similarity_boost: float | None = Field(None, description="Similarity boost")
style: float | None = Field(None, description="Style exaggeration")
use_speaker_boost: bool | None = Field(None, description="Boost similarity to original speaker")
speed: float | None = Field(None, description="Speech speed")
class TextToSpeechRequest(BaseModel):
text: str = Field(..., description="Text to convert to speech")
model_id: str = Field(..., description="Model ID for TTS")
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
voice_settings: TextToSpeechVoiceSettings | None = Field(None, description="Voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")
class TextToSoundEffectsRequest(BaseModel):
text: str = Field(..., description="Text prompt to convert into a sound effect")
duration_seconds: float = Field(..., description="Duration of generated sound in seconds")
prompt_influence: float = Field(..., description="How closely generation follows the prompt")
loop: bool | None = Field(None, description="Whether to create a smoothly looping sound effect")
class AddVoiceRequest(BaseModel):
name: str = Field(..., description="Name that identifies the voice")
remove_background_noise: bool = Field(..., description="Remove background noise from voice samples")
class AddVoiceResponse(BaseModel):
voice_id: str = Field(..., description="The newly created voice's unique identifier")
class SpeechToSpeechRequest(BaseModel):
model_id: str = Field(..., description="Model ID for speech-to-speech")
voice_settings: str = Field(..., description="JSON string of voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
remove_background_noise: bool = Field(..., description="Remove background noise from input audio")
class DialogueInput(BaseModel):
text: str = Field(..., description="Text content to convert to speech")
voice_id: str = Field(..., description="Voice identifier for this dialogue segment")
class DialogueSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability (0-1)")
class TextToDialogueRequest(BaseModel):
inputs: list[DialogueInput] = Field(..., description="List of dialogue segments")
model_id: str = Field(..., description="Model ID for dialogue generation")
language_code: str | None = Field(None, description="ISO-639-1 language code")
settings: DialogueSettings | None = Field(None, description="Voice settings")
seed: int | None = Field(None, description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")

View File

@@ -116,15 +116,9 @@ class GeminiGenerationConfig(BaseModel):
topP: float | None = Field(None, ge=0.0, le=1.0)
class GeminiImageOutputOptions(BaseModel):
mimeType: str = Field("image/png")
compressionQuality: int | None = Field(None)
class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None)
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiImageGenerationConfig(GeminiGenerationConfig):

View File

@@ -64,23 +64,3 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
Url: str = Field(...)
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)

View File

@@ -134,13 +134,6 @@ class ImageToVideoWithAudioRequest(BaseModel):
shot_type: str | None = Field(None)
class KlingAvatarRequest(BaseModel):
image: str = Field(...)
sound_file: str = Field(...)
prompt: str | None = Field(None)
mode: str = Field(...)
class MotionControlRequest(BaseModel):
prompt: str = Field(...)
image_url: str = Field(...)

Some files were not shown because too many files have changed in this diff Show More